# Same thing, but torch manual_seed instead of torch generators 
-> Perfectly reproducible!

In [1]:
import os
print(os.getcwd())

# os.chdir("/home/u699081/FOLDER_comdo")
os.chdir("/home/andrei/Desktop/PROJECT_ELLIS_COMDO/FOLDER_code")
import equinox as eqx
import jax
import jax.numpy as jnp
# import optax  # https://github.com/deepmind/optax
import torch  # https://pytorch.org
import torchvision  # https://pytorch.org
from jaxtyping import Array, Float, Int, PyTree  # https://github.com/google/jaxtyping
import sklearn
from sklearn.model_selection import train_test_split
# import seaborn as sns
import matplotlib.pyplot as plt
import copy
from torch.utils.tensorboard import SummaryWriter
import os

import time
import comdo
import pickle



# _________________________ Utils 

class CNN(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        # Standard CNN setup: convolutional layer, followed by flattening,
        # with a small MLP on top.
        self.layers = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=key1),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
            eqx.nn.Linear(1728, 512, key=key2),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=key4),
            jax.nn.log_softmax,
        ]

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x




def loss(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    # Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operations on
    # a single input input image of shape (1, 28, 28).
    #
    # Therefore, we have to use jax.vmap, which in this case maps our model over the
    # leading (batch) axis.
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)


loss = eqx.filter_jit(loss)  # JIT our loss function from earlier!


@eqx.filter_jit
def compute_accuracy(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """This function takes as input the current model
    and computes the average accuracy on a batch.
    """
    pred_y = jax.vmap(model)(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)


def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    """
    avg_loss = 0
    avg_acc = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        # Note that all the JAX operations happen inside `loss` and `compute_accuracy`,
        # and both have JIT wrappers, so this is fast.
        avg_loss += loss(model, x, y)
        avg_acc += compute_accuracy(model, x, y)
    return avg_loss / len(testloader), avg_acc / len(testloader)


# _________________________________________________________________________















# Hyperparameters

# TODO: choose nr. of steps (default = 600)
STEPS = 1800

BATCH_SIZE = 64

LEARNING_RATE = 0.005

PRINT_EVERY = 10




# ________________________ train script __________________-

from comdo.utils_ANNs import DOptimizer

def train(
    optimizer,
    models: CNN,
    trainloader_Agent1: torch.utils.data.DataLoader,
    trainloader_Agent2: torch.utils.data.DataLoader,
    testloader_Agent1: torch.utils.data.DataLoader,
    testloader_Agent2: torch.utils.data.DataLoader,
    steps: int,
    print_every: int,
) -> CNN:
    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    # opt_state = optim.init(eqx.filter(model, eqx.is_array))



    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.


    # @eqx.filter_jit
    def make_step(
        models: CNN,
        x_agent1: Float[Array, "batch 1 28 28"],
        y_agent1: Int[Array, " batch"],
        x_agent2: Float[Array, "batch 1 28 28"],
        y_agent2: Int[Array, " batch"],

    ):

        # making list with the grads (pytrees) of th etwo agents
        grads_list = []
        grads_list.append(eqx.filter_value_and_grad(loss)(models[0], x_agent1, y_agent1)[1])
        grads_list.append(eqx.filter_value_and_grad(loss)(models[1], x_agent2, y_agent2)[1])

        # loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        # updates, opt_state = optim.update(grads, opt_state, model)

        # model = eqx.apply_updates(model, updates)
        
        models = optimizer.step_withMemory(models, grads_list)

        return models
    
    # Loop over our training dataset as many times as we need.
    def infinite_trainloader_agent1():
        while True:
            yield from trainloader_Agent1

    def infinite_trainloader_agent2():
        while True: 
            yield from trainloader_Agent2


    writer = SummaryWriter()

    list_test_accuracy = []
    list_test_loss =[]
    list_train_accuracy = []
    list_train_loss =[]

    for step, (x_agent1, y_agent1), (x_agent2, y_agent2) in zip(range(steps), infinite_trainloader_agent1(), infinite_trainloader_agent2()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x_agent1 = x_agent1.numpy()
        y_agent1 = y_agent1.numpy()

        x_agent2 = x_agent2.numpy()
        y_agent2 = y_agent2.numpy()
        
        # start = time.time()
        

        # initial allignent, in case of initial difference in states between agents

        if step == 0:
            for layer_i in optimizer.idx_layersWithWeights:

                    aux_update = (1/optimizer.n_agents) * sum( models[agent_j].layers[layer_i].weight for agent_j in range(optimizer.n_agents) )
                    
                    for agent_i in range(optimizer.n_agents):
                        where = lambda m: m[agent_i].layers[layer_i].weight
                        models = eqx.tree_at(where, models, aux_update)


        models = make_step(models, x_agent1, y_agent1, x_agent2, y_agent2)
            
        # print("Made step ", step )
        # print("Seconds to take this step: ", time.time() - start)

        # print("Z_g")
        # print(optimizer.z_g)

        if (step % print_every) == 0 or (step == steps - 1):

            test_loss = 0
            test_accuracy = 0
            train_loss = 0
            train_accuracy = 0

            # start = time.time()

            test_loss +=  float(evaluate(models[0], testloader_Agent1)[0])
            test_accuracy +=  float(evaluate(models[0], testloader_Agent1)[1])
            
            test_loss += float(evaluate(models[1], testloader_Agent2)[0])
            test_accuracy += float(evaluate(models[1], testloader_Agent2)[1])

            train_loss +=  float(evaluate(models[0], trainloader_Agent1)[0])
            train_accuracy += float(evaluate(models[0], trainloader_Agent1)[1])
            train_loss += float(evaluate(models[1], trainloader_Agent2)[0])
            train_accuracy += float(evaluate(models[1], trainloader_Agent2)[1])


            # print("Seconds to compute loss and accuracy: ", time.time() - start)

            # start = time.time()

            writer.add_scalar("global train loss", train_loss/2, step)   # printing means
            writer.add_scalar("global test loss", test_loss/2, step)

            writer.add_scalar("global train accuracy", train_accuracy/2, step)
            writer.add_scalar("global test accuracy", test_accuracy/2, step)

            # ______ appending performance to lists ______________-

            list_test_accuracy.append(test_accuracy/2) 
            list_test_loss.append(test_loss/2)
            list_train_accuracy.append(train_accuracy/2) 
            list_train_loss.append(train_loss/2)

            # if (step % print_every) == 0 or (step == steps - 1):
            print(
                f"global train_loss={train_loss/2}, gloabal train_accuracy={train_accuracy/2} "
                f"global test_loss={test_loss/2}, global test_accuracy={test_accuracy/2}" )
            
            # print("Seconds to write and print: ", time.time() - start)

    list_test_Accuracies_acrossInitializations.append(list_test_accuracy)
    list_test_Loss_acrossInitializations.append(list_test_loss)
    list_train_Accuracies_acrossInitializations.append(list_train_accuracy)
    list_train_Loss_acrossInitializations.append(list_train_loss)

    return models



# TODO: fill in range with the number of runs you want
SEEDS = range(5)



list_test_Accuracies_acrossInitializations = []
list_test_Loss_acrossInitializations = []
list_train_Accuracies_acrossInitializations = []
list_train_Loss_acrossInitializations = []

for SEED in SEEDS:

    key = jax.random.PRNGKey(SEED)
    key, subkey = jax.random.split(key, 2)

    agent1 = CNN(subkey)
    agent2 = CNN(subkey+100)

    models = [agent1, agent2]

    optimizer_fractional = comdo.utils_ANNs.DOptimizer(models= models,
                                            beta_c = 1,
                                            beta_g = LEARNING_RATE,
                                            beta_gm = LEARNING_RATE / 2)                                        

    torch.manual_seed(SEED)

    normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
    )

    train_dataset = torchvision.datasets.MNIST(
        "MNIST",
        train=True,
        download=True,
        transform=normalise_data,
    )
    test_dataset = torchvision.datasets.MNIST(
        "MNIST",
        train=False,
        download=True,
        transform=normalise_data,
    )
    trainloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True

    )
    testloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=BATCH_SIZE, shuffle=True
    )

    # ____________________ splitting MNIST into 2 distinct datasets __________________
    # import os
    # os.chdir("/home/andrei/Desktop/PROJECT_ELLIS_COMDO/FOLDER_code")

    from comdo.utils_ANNs import get_2DO_datasets


    Agent1_Train_dataset, Agent1_Test_dataset, Agent2_Train_dataset, Agent2_Test_dataset = \
        get_2DO_datasets(train_dataset= train_dataset, test_dataset= test_dataset, seed= SEED)


    trainloader_Agent1 = torch.utils.data.DataLoader(
        Agent1_Train_dataset, batch_size=BATCH_SIZE, shuffle=True
    )
    testloader_Agent1 = torch.utils.data.DataLoader(
        Agent1_Test_dataset, batch_size=BATCH_SIZE, shuffle=True
    )

    trainloader_Agent2 = torch.utils.data.DataLoader(
        Agent2_Train_dataset, batch_size=BATCH_SIZE, shuffle=True
    )
    testloader_Agent2 = torch.utils.data.DataLoader(
        Agent2_Test_dataset, batch_size=BATCH_SIZE, shuffle=True
    )


    trained_models = train(optimizer= optimizer_fractional, models= models, trainloader_Agent1= trainloader_Agent1, trainloader_Agent2= trainloader_Agent2, testloader_Agent1= testloader_Agent1, testloader_Agent2= testloader_Agent2, steps= STEPS, print_every= PRINT_EVERY)


# TODO change results path

with open('/home/u699081/FOLDER_comdo/SIMULATIONS/RESULTS/ANNs/Fractional/test_acc.pkl', 'wb') as f:
    pickle.dump(list_test_Accuracies_acrossInitializations, f)

with open('/home/u699081/FOLDER_comdo/SIMULATIONS/RESULTS/ANNs/Fractional/test_loss.pkl', 'wb') as f:
    pickle.dump(list_test_Loss_acrossInitializations, f)

with open('/home/u699081/FOLDER_comdo/SIMULATIONS/RESULTS/ANNs/Fractional/train_acc.pkl', 'wb') as f:
    pickle.dump(list_train_Accuracies_acrossInitializations, f)

with open('/home/u699081/FOLDER_comdo/SIMULATIONS/RESULTS/ANNs/Fractional/train_loss.pkl', 'wb') as f:
    pickle.dump(list_train_Loss_acrossInitializations, f)

/home/andrei/Desktop/PROJECT_ELLIS_COMDO/FOLDER_code
global train_loss=2.305303692817688, gloabal train_accuracy=0.09670983627438545 global test_loss=2.3039047718048096, global test_accuracy=0.09661788120865822
global train_loss=2.1893906593322754, gloabal train_accuracy=0.2486577033996582 global test_loss=2.1851030588150024, global test_accuracy=0.2434731051325798
global train_loss=2.080592393875122, gloabal train_accuracy=0.5451133847236633 global test_loss=2.0726053714752197, global test_accuracy=0.552808552980423
global train_loss=1.9616005420684814, gloabal train_accuracy=0.5400357097387314 global test_loss=1.950952410697937, global test_accuracy=0.5515229254961014
global train_loss=1.8147075772285461, gloabal train_accuracy=0.6582254469394684 global test_loss=1.803224503993988, global test_accuracy=0.6759295761585236
global train_loss=1.6652057766914368, gloabal train_accuracy=0.6330743432044983 global test_loss=1.6486510634422302, global test_accuracy=0.6481408178806305


KeyboardInterrupt: 