# Training a basic ANN on MNIST

In [1]:
import equinox as eqx
import jax
import jax.numpy as jnp
import optax  # https://github.com/deepmind/optax
import torch  # https://pytorch.org
import torchvision  # https://pytorch.org
from jaxtyping import Array, Float, Int, PyTree  # https://github.com/google/jaxtyping
import sklearn
from sklearn.model_selection import train_test_split
import seaborn as sns
import matplotlib.pyplot as plt
import copy
from torch.utils.tensorboard import SummaryWriter
import os

os.chdir("/home/andrei/Desktop/PROJECT_ELLIS_COMDO/FOLDER_code")
import time
import comdo

In [2]:
# Hyperparameters shared by the cells below.
BATCH_SIZE = 64        # samples per mini-batch for every DataLoader
LEARNING_RATE = 0.005  # base step size handed to the optimizers
STEPS = 300            # number of optimisation steps per training run
PRINT_EVERY = 10       # evaluate and log metrics every this many steps
SEED = 5678            # seed for the JAX PRNG

# Root PRNG key derived from SEED.
key = jax.random.PRNGKey(SEED)

### Loading dataset

In [3]:
# Preprocessing applied to every MNIST image: convert to a tensor, then
# normalise with mean 0.5 and standard deviation 0.5.
normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Official train / test splits, downloaded into the "MNIST" root directory if missing.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(
        "MNIST", train=use_train_split, download=True, transform=normalise_data
    )
    for use_train_split in (True, False)
)

# Shuffled mini-batch iterators over the full (unsplit) datasets.
trainloader, testloader = (
    torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    for dataset in (train_dataset, test_dataset)
)

### Splitting MNIST into 2 balanced, distinct datasets

In [4]:

import os
os.chdir("/home/andrei/Desktop/PROJECT_ELLIS_COMDO/FOLDER_code")

from comdo.utils_ANNs import get_2DO_datasets


# Split MNIST into one train/test pair per agent. The splitting logic lives in
# comdo.utils_ANNs.get_2DO_datasets, which is not part of this notebook.
(
    Agent1_Train_dataset,
    Agent1_Test_dataset,
    Agent2_Train_dataset,
    Agent2_Test_dataset,
) = get_2DO_datasets(train_dataset=train_dataset, test_dataset=test_dataset)

# Every per-agent loader uses the same batch size and reshuffles on each pass.
_loader_options = dict(batch_size=BATCH_SIZE, shuffle=True)

trainloader_Agent1 = torch.utils.data.DataLoader(Agent1_Train_dataset, **_loader_options)
testloader_Agent1 = torch.utils.data.DataLoader(Agent1_Test_dataset, **_loader_options)

trainloader_Agent2 = torch.utils.data.DataLoader(Agent2_Train_dataset, **_loader_options)
testloader_Agent2 = torch.utils.data.DataLoader(Agent2_Test_dataset, **_loader_options)


## Making 2 ANNs with the exact same initialization 

In [5]:
class CNN(eqx.Module):
    """Small convolutional classifier for a single MNIST image.

    One convolution + max-pool stage whose output is flattened and passed
    through a three-layer MLP ending in a log-softmax over the 10 classes.
    """

    layers: list  # callables applied one after another by __call__

    def __init__(self, key):
        """Build the layer stack, using one PRNG subkey per parametrised layer."""
        conv_key, hidden1_key, hidden2_key, output_key = jax.random.split(key, 4)

        feature_extractor = [
            eqx.nn.Conv2d(1, 3, kernel_size=4, key=conv_key),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            jnp.ravel,
        ]
        classifier = [
            # 1728 flattened inputs -- presumably 3 channels x 24 x 24 with
            # Equinox's default strides/padding; TODO confirm if the conv or
            # pooling settings ever change.
            eqx.nn.Linear(1728, 512, key=hidden1_key),
            jax.nn.sigmoid,
            eqx.nn.Linear(512, 64, key=hidden2_key),
            jax.nn.relu,
            eqx.nn.Linear(64, 10, key=output_key),
            jax.nn.log_softmax,
        ]
        self.layers = feature_extractor + classifier

    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        """Run one (unbatched) image through every layer and return log-probabilities."""
        activation = x
        for apply_layer in self.layers:
            activation = apply_layer(activation)
        return activation

# Re-create the PRNG key from the seed and split off one subkey for model initialisation.
SEED = 5678
key, subkey = jax.random.split(jax.random.PRNGKey(SEED), 2)

# Both agents are constructed from the same subkey, so they start from identical parameters.
agent1, agent2 = CNN(subkey), CNN(subkey)

models = [agent1, agent2]

## Inspecting the PyTree - The weights should be equal





In [7]:
print("Do agent 1 and 2 have the same weights in the first layer?")
# Compare the tensors element-wise. The previous check compared the two
# `.all()` reductions, which only tells whether both tensors are (or are not)
# entirely non-zero and can therefore print True for different weights.
print(jnp.array_equal(agent1.layers[0].weight, agent2.layers[0].weight))

# Visual confirmation: the first output channel's kernel of each agent.
print(agent1.layers[0].weight[0])  # I expect them to be equal
print(agent2.layers[0].weight[0])



Do agent 1 and 2 have the same weights in the first layer?
True
[[[-0.0246132   0.20315117 -0.12337857  0.1912669 ]
  [ 0.01224852  0.03098577 -0.17678964  0.18533075]
  [-0.00282699 -0.12770635 -0.10529053 -0.24286664]
  [-0.05992258  0.18098432 -0.22828996  0.21605003]]]
[[[-0.0246132   0.20315117 -0.12337857  0.1912669 ]
  [ 0.01224852  0.03098577 -0.17678964  0.18533075]
  [-0.00282699 -0.12770635 -0.10529053 -0.24286664]
  [-0.05992258  0.18098432 -0.22828996  0.21605003]]]


In [6]:
def loss(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """Return the mean cross-entropy of `model` on one batch.

    The model operates on a single image of shape (1, 28, 28), whereas `x`
    carries a leading batch axis, so the model is mapped over that axis with
    `jax.vmap` before the predictions are scored against the labels `y`.
    """
    batched_model = jax.vmap(model)
    log_probabilities = batched_model(x)
    return cross_entropy(y, log_probabilities)


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    """Negative mean log-probability assigned to the true classes.

    `y` holds the true targets as integer class indices (0-9) and `pred_y`
    holds the log-softmax'd predictions, one row per sample.
    """
    # Pick, for every row, the log-probability at that row's target index.
    target_columns = jnp.expand_dims(y, 1)
    target_log_probs = jnp.take_along_axis(pred_y, target_columns, axis=1)
    return -jnp.mean(target_log_probs)

# Pull a single batch from agent 1's training loader and convert it to JAX arrays.
dummy_x, dummy_y = (jnp.array(tensor) for tensor in next(iter(trainloader_Agent1)))

# Example loss
loss_value = loss(agent1, dummy_x, dummy_y)
print(loss_value.shape)  # scalar loss

# Example inference: map the single-image model over the batch axis.
output = jax.vmap(agent1)(dummy_x)
print(output.shape)  # batch of predictions

()
(64, 10)


In [9]:
# Loss value on the dummy batch together with its gradients with respect to agent1.
loss_and_grads = eqx.filter_value_and_grad(loss)
value, grads = loss_and_grads(agent1, dummy_x, dummy_y)
print(value)


2.3517985


# Evaluation

In [7]:
# Rebind `loss` to its JIT-compiled wrapper; the functions defined below call it through this name.
# NOTE(review): re-running this cell wraps the already-wrapped function again.
loss = eqx.filter_jit(loss)  # JIT our loss function from earlier!


@eqx.filter_jit
def compute_accuracy(
    model: CNN, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    """Fraction of samples in the batch whose most likely class equals the label.

    The model is mapped over the batch axis of `x`; the predicted class is the
    argmax over the class axis of its output.
    """
    batch_scores = jax.vmap(model)(x)
    predicted_classes = jnp.argmax(batch_scores, axis=1)
    is_correct = y == predicted_classes
    return jnp.mean(is_correct)


In [8]:
def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
    """Return `(average loss, average accuracy)` of `model` over `testloader`.

    Both averages are unweighted means of the per-batch values, i.e. the sums
    are divided by the number of batches in the loader.
    """
    loss_total = 0
    accuracy_total = 0
    for batch in testloader:
        # The loader yields PyTorch tensors; hand NumPy arrays to the JIT-wrapped
        # `loss` and `compute_accuracy`, where all the JAX work happens.
        images, labels = (tensor.numpy() for tensor in batch)
        loss_total = loss_total + loss(model, images, labels)
        accuracy_total = accuracy_total + compute_accuracy(model, images, labels)
    n_batches = len(testloader)
    return loss_total / n_batches, accuracy_total / n_batches


In [12]:
# Baseline before any training: (average loss, average accuracy) of agent 1 on its own test loader.
evaluate(agent1, testloader_Agent1)


(Array(2.3291311, dtype=float32), Array(0.08623417, dtype=float32))

# Training

In [13]:
# Evaluate once and reuse the result: the test loader reshuffles on every pass,
# so two separate evaluations would not even describe the same number.
initial_test_loss = float(evaluate(models[0], testloader_Agent1)[0])
print(initial_test_loss)
# A single print here -- wrapping it in a second print() only adds a stray "None".
print(type(initial_test_loss))

2.3293533325195312
<class 'float'>
None


In [11]:

from comdo.utils_ANNs import DOptimizer

def train(
    models: list[CNN],
    trainloader_Agent1: torch.utils.data.DataLoader,
    trainloader_Agent2: torch.utils.data.DataLoader,
    testloader_Agent1: torch.utils.data.DataLoader,
    testloader_Agent2: torch.utils.data.DataLoader,
    steps: int,
    print_every: int,
) -> list[CNN]:
    """Train two agents jointly with `comdo.utils_ANNs.DOptimizer`.

    On every step one batch is drawn per agent from that agent's own training
    loader, the gradient of `loss` is computed for each agent, and both
    gradients are handed to `DOptimizer.step_withMemory`, whose return value
    becomes the new list of models.

    Every `print_every` steps (and on the final step) each agent is evaluated
    on its own train and test loaders; the means over the two agents are
    printed and written to TensorBoard.

    Args:
        models: the two agents, `models[0]` for agent 1 and `models[1]` for agent 2.
        trainloader_Agent1, trainloader_Agent2: training data of each agent.
        testloader_Agent1, testloader_Agent2: held-out data of each agent.
        steps: number of optimisation steps to run.
        print_every: interval (in steps) between metric evaluations.

    Returns:
        The models returned by the last optimizer step (the input `models`
        unchanged when `steps` is 0).
    """
    optimizer = comdo.utils_ANNs.DOptimizer(
        models=models,
        beta_c=LEARNING_RATE,
        beta_g=LEARNING_RATE,
        beta_gm=LEARNING_RATE / 2,
    )

    # Built once and reused on every step.
    loss_and_grad = eqx.filter_value_and_grad(loss)

    # NOTE(review): this step runs un-jitted (the `@eqx.filter_jit` decorator was
    # disabled in the original). Confirm whether `DOptimizer.step_withMemory`
    # can be traced before re-enabling it.
    def make_step(
        models: list[CNN],
        x_agent1: Float[Array, "batch 1 28 28"],
        y_agent1: Int[Array, " batch"],
        x_agent2: Float[Array, "batch 1 28 28"],
        y_agent2: Int[Array, " batch"],
    ):
        # One gradient PyTree per agent, each computed on that agent's own batch.
        grads_list = [
            loss_and_grad(models[0], x_agent1, y_agent1)[1],
            loss_and_grad(models[1], x_agent2, y_agent2)[1],
        ]
        return optimizer.step_withMemory(models, grads_list)

    def infinite_batches(loader):
        # Loop over a loader as many times as needed.
        while True:
            yield from loader

    def mean_metrics(current_models, loader_agent1, loader_agent2):
        # Evaluate each agent exactly once per loader and average over the two
        # agents. Previously loss and accuracy each triggered their own full
        # pass over the data, doubling the (dominant) evaluation cost.
        loss_1, accuracy_1 = evaluate(current_models[0], loader_agent1)
        loss_2, accuracy_2 = evaluate(current_models[1], loader_agent2)
        mean_loss = (float(loss_1) + float(loss_2)) / 2
        mean_accuracy = (float(accuracy_1) + float(accuracy_2)) / 2
        return mean_loss, mean_accuracy

    writer = SummaryWriter()
    try:
        batches = zip(
            range(steps),
            infinite_batches(trainloader_Agent1),
            infinite_batches(trainloader_Agent2),
        )
        for step, (x_agent1, y_agent1), (x_agent2, y_agent2) in batches:
            # PyTorch dataloaders give PyTorch tensors by default,
            # so convert them to NumPy arrays.
            models = make_step(
                models,
                x_agent1.numpy(),
                y_agent1.numpy(),
                x_agent2.numpy(),
                y_agent2.numpy(),
            )

            if (step % print_every) == 0 or (step == steps - 1):
                test_loss, test_accuracy = mean_metrics(
                    models, testloader_Agent1, testloader_Agent2
                )
                train_loss, train_accuracy = mean_metrics(
                    models, trainloader_Agent1, trainloader_Agent2
                )

                writer.add_scalar("global train loss", train_loss, step)
                writer.add_scalar("global test loss", test_loss, step)

                writer.add_scalar("global train accuracy", train_accuracy, step)
                writer.add_scalar("global test accuracy", test_accuracy, step)

                print(
                    f"global train_loss={train_loss}, gloabal train_accuracy={train_accuracy} "
                    f"global test_loss={test_loss}, global test_accuracy={test_accuracy}" )
    finally:
        # Flush pending events and release the log file even when the run is
        # interrupted (e.g. KeyboardInterrupt during a long evaluation).
        writer.close()

    return models

In [12]:
import inspect
import os

# Run the two-agent training; `model` receives whatever `train` returns.
model = train(
    models=models,
    trainloader_Agent1=trainloader_Agent1,
    trainloader_Agent2=trainloader_Agent2,
    testloader_Agent1=testloader_Agent1,
    testloader_Agent2=testloader_Agent2,
    steps=STEPS,
    print_every=PRINT_EVERY,
)


Made step  0
Seconds to take this step:  0.9320087432861328
global train_loss=2.3305402994155884, gloabal train_accuracy=0.08047902584075928 global test_loss=2.3268001079559326, global test_accuracy=0.08890426903963089
Made step  1
Seconds to take this step:  0.701838493347168
Made step  2
Seconds to take this step:  0.7393503189086914
Made step  3
Seconds to take this step:  0.7125844955444336
Made step  4
Seconds to take this step:  0.710456371307373
Made step  5
Seconds to take this step:  0.7114512920379639
Made step  6
Seconds to take this step:  0.7384293079376221
Made step  7
Seconds to take this step:  0.6841535568237305
Made step  8
Seconds to take this step:  0.6873557567596436
Made step  9
Seconds to take this step:  0.7093634605407715
Made step  10
Seconds to take this step:  0.7276513576507568
global train_loss=2.1663317680358887, gloabal train_accuracy=0.3602879047393799 global test_loss=2.1595379114151, global test_accuracy=0.3799446225166321
Made step  11
Seconds to ta

KeyboardInterrupt: 

# It takes < 1 second to make a step, but 20 seconds to compute metrics, so compute metrics more sparsely!

# It trains! 

Awesome - I get 0.85 accuracy in 100 iterations;

Very slow - for 300 iterations I need 100 minutes;

### Plan
- put **2h into jit'ing** whatever can be jit'ed
- if unsuccessful -> train only for 50 iterations (15 min) (should get 80% acc. )
    - for 3 algorithms and 2 conditions with 10 runs each (60 runs of 15 min), it should take 15h

In [None]:


# NOTE(review): leftover scratch line. `self` is not defined at notebook top level in the
# visible source, so running this cell would presumably raise NameError. It looks like a
# fragment taken from a class method -- confirm, then delete it or move it back.
gradient_term = jnp.zeros(self.shape_z_g)


### Train Centralized

In [None]:
def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    """Train a single model with an Optax optimizer (centralized baseline).

    Every `print_every` steps (and on the final step) the model is evaluated
    on the full `trainloader` and `testloader`; the metrics are printed and
    written to TensorBoard.

    NOTE(review): this notebook defines `train` more than once; whichever
    definition was executed last is the one in effect.

    Args:
        model: the network to train.
        trainloader: source of training batches, cycled as often as needed.
        testloader: held-out data used for the test metrics.
        optim: the Optax gradient transformation used for the updates.
        steps: number of optimisation steps to run.
        print_every: interval (in steps) between metric evaluations.

    Returns:
        The model after the last update.
    """
    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)

        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from trainloader

    writer = SummaryWriter()
    try:
        for step, (x, y) in zip(range(steps), infinite_trainloader()):
            # PyTorch dataloaders give PyTorch tensors by default,
            # so convert them to NumPy arrays.
            # The per-batch loss returned by make_step is not used: the logged
            # train loss comes from the full evaluation below.
            model, opt_state, _ = make_step(model, opt_state, x.numpy(), y.numpy())

            if (step % print_every) == 0 or (step == steps - 1):
                test_loss, test_accuracy = evaluate(model, testloader)
                train_loss, train_accuracy = evaluate(model, trainloader)

                writer.add_scalar("train loss", float(train_loss), step)
                writer.add_scalar("test loss", float(test_loss), step)

                writer.add_scalar("train accuracy", float(train_accuracy), step)
                writer.add_scalar("test accuracy", float(test_accuracy), step)

                print(
                    f"train_loss={train_loss.item()}, train_accuracy={train_accuracy.item()} "
                    f"test_loss={test_loss.item()}, test_accuracy={test_accuracy.item()}"
                )
    finally:
        # Flush pending events and release the log file even if training is interrupted.
        writer.close()
    return model

In [None]:
print("type loss value")

# Evaluate once and reuse the value instead of running a second full pass
# over the test loader just to inspect the type.
agent1_test_loss = float(evaluate(agent1, testloader_Agent1)[0])
print(agent1_test_loss)
print(type(agent1_test_loss))


In [None]:

# NOTE(review): unfinished draft that mixes the single-model Optax loop with the
# two-agent DOptimizer update. This cell cannot run as written (see the notes inside);
# it also reuses the name `train`, which this notebook defines more than once.
models = [agent1, agent2]

def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> CNN:
    # Just like earlier: It only makes sense to train the arrays in our model,
    # so filter out everything else.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        updates, opt_state = optim.update(grads, opt_state, model)

        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Loop over our training dataset as many times as we need.
    def infinite_trainloader():
        while True:
            yield from trainloader


    writer = SummaryWriter()

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
    
        test_loss = 0
        test_accuracy = 0
        train_loss = 0
        train_accuracy = 0

        # NOTE(review): this loop rebinds `model`, shadowing the function argument that
        # make_step has just updated.
        for model in models:
            # NOTE(review): augmented assignment to a tuple target is a SyntaxError in
            # Python, so the whole cell fails to compile at the next two lines.
            test_loss, test_accuracy += float(evaluate(model, testloader)[0]), float(evaluate(model, testloader)[1])
            train_loss, train_accuracy += float(evaluate(model, trainloader)[0]), float(evaluate(model, trainloader)[1])

        # NOTE(review): the accumulators are initialised to the number 0 above, but
        # sum()/len() below treat them as sequences -- presumably lists were intended.
        writer.add_scalar("global train loss", sum(train_loss)/len(train_loss), step)   # printing means
        writer.add_scalar("global test loss", sum(test_loss)/len(test_loss), step)

        writer.add_scalar("global train accuracy", sum(train_accuracy)/len(train_accuracy), step)
        writer.add_scalar("global test accuracy", sum(test_accuracy)/len(test_accuracy), step)

        if (step % print_every) == 0 or (step == steps - 1):
            print(
                f"global train_loss={sum(train_loss)}, gloabal train_accuracy={sum(train_accuracy)/len(train_accuracy)} "
                f"global test_loss={sum(test_loss)}, global test_accuracy={sum(test_accuracy)/len(test_accuracy)}"
            )


        # NOTE(review): the optimizer is re-created on every step, and `fs_private`,
        # `memory_len`, `beta_c`, `beta_g` and `beta_gm` are not defined anywhere in the
        # visible notebook -- verify against DOptimizer's actual signature.
        optimizer = DOptimizer( models = models,
                                fs_private = fs_private,
                                len_memory = len(models),
                                memory_len = memory_len,
                                beta_c = beta_c,
                                beta_g = beta_g,
                                beta_gm = beta_gm)

        # One gradient per model, all computed on the same batch (x, y).
        grads_list = []
        for model in models:
            grads_list.append(eqx.filter_value_and_grad(loss)(model, x, y)[1])
        
        # NOTE(review): `model` is rebound to the return value of step_withMemory while
        # the global `models` list is left untouched -- confirm the intended update.
        model = optimizer.step_withMemory(models, grads_list)

    return model

In [None]:
import inspect
import os


# Optimizer for the centralized (single-model) baseline.
optim = optax.adamw(LEARNING_RATE)

# Removed leftover debugging: a dump of `update.__globals__` and a bare
# `optax.adam(LEARNING_RATE).update()` call, which raises TypeError because
# `update` requires the gradients and the optimizer state.

# Train a single network. `agent1` is passed explicitly: on a fresh kernel
# `model` is either undefined or holds the list returned by the two-agent
# training run, neither of which the single-model `train` can use.
model = train(agent1, trainloader, testloader, optim, STEPS, PRINT_EVERY)

In [None]:
# NOTE(review): scratch experiment -- calls optax.adam with a dict in place of the
# learning rate and discards the result. TODO confirm whether this cell is still needed.
optax.adam({"e":0})

In [None]:
# Show the weights of the first layer of `model`.
print(model.layers[0].weight)

Archive

- 0.38 loss - 3.5 min

BATCH_SIZE = 64

LEARNING_RATE = 3e-4

STEPS = 300

PRINT_EVERY = 30

SEED = 5678


- failed (no usable result)

BATCH_SIZE = 64

LEARNING_RATE = 3e-2

STEPS = 300

PRINT_EVERY = 30

SEED = 5678


- 0.38 loss

BATCH_SIZE = 64

LEARNING_RATE = 2*3e-4

STEPS = 300

PRINT_EVERY = 30

SEED = 5678


- 0.15 loss

BATCH_SIZE = 64

LEARNING_RATE = 0.005

STEPS = 300

PRINT_EVERY = 30

SEED = 5678