In [1]:
import collections
import enum

import numpy as np
import torch
print(torch.__version__)

import jax.numpy as jnp
import jax

import rich
import sklearn.datasets
import tqdm.notebook as tqdm

In [44]:
class CVSets(enum.Enum):
    TRAIN = 0
    VALID = 1


class LearningMode(enum.Enum):
    GRADIENT_DESCENT = 0
    NEWTON = 1


def cross_entropy_fn(batch_data, batch_target, model_fn, model_weights, model_biases):
    data = batch_data
    print(f"{batch_target.shape = }")
    target = jnp.expand_dims(batch_target, -1)

    logits = model_fn(data, model_weights, model_biases)
    log_softmax = jax.nn.log_softmax(logits)

    print(f"{log_softmax.shape = }")
    print(f"{target.shape = }")
    loss = - jnp.take_along_axis(log_softmax, target, axis=-1)
    loss = loss.mean()

    return loss, logits


def collate_fn(uncollated):
    keys = uncollated[0].keys()
    output = {}
    for key in keys:
        output_array = []
        for entry in uncollated:
            output_array.append(entry[key])    
        output[key] = jnp.stack(output_array)

    return output


class Digits(torch.utils.data.Dataset):
    def __init__(self, data):
        assert isinstance(data, dict), type(data).mro()
        assert isinstance(data["data"], np.ndarray), type(data).mro()
        self._data = data

    def __getitem__(self, *args, **kwargs):
        return {
            k: v.__getitem__(*args, **kwargs) 
            for k, v in self._data.items()
        }

    def __len__(self, ):
        return len(self._data["data"])


@jax.jit
def model_fn(x, model_weights, model_biases):
    return x @ model_weights + model_biases


vmap_in_axes = [0, 0, None, None]
vmap_out_axes = 0

DO_VMAP = True

def maybe_vmap(fun, *args, **kwargs):
    
    if DO_VMAP:
        return jax.vmap(fun, *args, **kwargs)
    else:
        return fun

print("grad_w")
grad_w = maybe_vmap(
    jax.grad(
    fun=lambda batch_data, batch_target, weights, biases: 
        cross_entropy_fn(
            batch_data,
            batch_target,
            model_fn, 
            weights,
            biases,
        ),
    argnums=2,
    has_aux=True,
), 
in_axes=vmap_in_axes,
out_axes=vmap_out_axes,
)

print("grad_b")
grad_b = maybe_vmap(
    jax.grad(
        fun=lambda batch_data, batch_target, weights, biases: 
            cross_entropy_fn(
                batch_data,
                batch_target,
                model_fn, 
                weights,
                biases,
            ),
        argnums=3,
        has_aux=True,
), 
out_axes=vmap_out_axes,
in_axes=vmap_in_axes,
)

grad_w = grad_w
grad_b = grad_b

def grad_desc(weights, biases, batch_data, batch_target, lr):
    assert weights.ndim == 2, weights.shape
    assert biases.ndim == 1, weights.shape

    assert batch_data.ndim == 1, batch_data.shape
    assert batch_target.ndim == 0, batch_target.shape

    vmap_grad_w = grad_w(batch_data, batch_target, weights, biases)[0]
    vmap_grad_b = grad_b(batch_data, batch_target, weights, biases)[0]

    new_weights = weights - lr * vmap_grad_w
    new_biases = biases - lr * vmap_grad_b

    assert new_weights.ndim == 2, new_weights.shape
    assert new_biases.ndim == 1, new_biases.shape

    assert new_weights.shape
    assert new_biases.shape

    return new_weights, new_biases

# hessian_w = jax.hessian(
#     fun=lambda batch_data, batch_target, weights, biases: 
#         cross_entropy_fn(
#             batch_data, 
#             batch_target,
#             model_fn, 
#             weights,
#             biases,
#         ),
#     argnums=3,
#     has_aux=True,
# )

# hessian_b = jax.hessian(
#     fun=lambda batch_data, batch_target, weights, biases: 
#         cross_entropy_fn(
#             batch_data, 
#             batch_target,
#             model_fn, 
#             weights,
#             biases,
#         ),
#     argnums=4,
#     has_aux=True,
# )


# def newtons(weights, biases, batch_data, batch_target):
#     val_hessian_w = hessian_w(batch_data, batch_target, weights, biases)[0]
#     val_hessian_b = hessian_b(batch_data, batch_target, weights, biases)[0]
    
#     inv_hessian_w = jnp.linalg.inv(val_hessian_w)
#     inv_hessian_b = jnp.linalg.inv(val_hessian_b)

#     new_weights = weights - inv_hessian_w @ grad_w(batch_data, batch_target, weights, biases)[0]
#     new_biases  = biases  - inv_hessian_b @ grad_b(batch_data, batch_target, weights, biases)[0]

#     return new_weights, new_biases


def step(*, batch_data, batch_target, model_weights, model_biases, model_fn, cv_set, lr, learning_mode):
    print(
        f"loss comp: "
        f"\n\t - {vmap_in_axes = } "
        f"\n\t - {batch_data.shape = } "
        f"\n\t - {batch_target.shape = } "
        f"\n\t - {cv_set = }"
    )

    loss, logits = maybe_vmap(
        lambda batch_data, batch_target, weights, biases: 
            cross_entropy_fn(
                batch_data,
                batch_target,
                model_fn, 
                weights,
                biases,
            ),
        in_axes=vmap_in_axes,
        out_axes=0,
    )(batch_data, batch_target, model_weights, model_biases)

    accuracy = (logits.argmax(-1) == batch_target).mean()
    
    if cv_set == CVSets.TRAIN:
        if learning_mode == LearningMode.GRADIENT_DESCENT:

            print("Grad Desc:")
            model_weights, model_biases = jax.tree_map(
                lambda w, b: grad_desc(w, b, batch_data, batch_target, lr),
                model_weights,
                model_biases,
            )

        elif learning_mode == LearningMode.NEWTON:
            model_weights, model_biases = jax.tree_map(
                lambda weights, biases: 
                    newtons(weights, biases, batch_data, batch_target),
                model_weights,
                model_biases,
            )

    return loss, dict(accuracy=accuracy), model_weights, model_biases


def train():
    NUM_EPOCHS = 5
    LEARNING_RATE = 0.1
    LEARNING_MODE = LearningMode.GRADIENT_DESCENT # LearningMode.NEWTON # 
    PRNG_KEY = jax.random.PRNGKey(0)

    print("Loading digits.")
    digits = sklearn.datasets.load_digits()
    print("Done loading digits.")

    delim = int(len(digits.data) * .8)
    train_data = digits.data[:delim]
    std = np.std(train_data, axis=0)
    std[std == 0] = 1

    mean = np.mean(train_data, axis=0)
    train_data -= mean
    train_data /= std
    valid_data = digits.data[delim:]
    valid_data -= mean
    valid_data /= std

    datasets = {
        CVSets.TRAIN: Digits(dict(data=train_data, target=digits.target[:delim])),
        CVSets.VALID: Digits(dict(data=valid_data, target=digits.target[delim:])),
    }

    dataloaders = {
        cv_set: torch.utils.data.DataLoader(
            dataset,
            batch_size=10,
            shuffle=False,
            collate_fn=collate_fn
        )
        
        for cv_set, dataset
        in datasets.items()
    }

    model_weights = jax.random.uniform(PRNG_KEY, shape=(64, 10,))
    model_biases = jax.random.uniform(PRNG_KEY, shape=(10,))
    

    for epoch in range(NUM_EPOCHS):
        for cv_set in [CVSets.TRAIN, CVSets.VALID]:
            all_metrics = collections.defaultdict(list)
            for batch in tqdm.tqdm(dataloaders[cv_set]):
                print("\nStep:")
                loss, metrics, model_weights, model_biases = step(
                    batch_data=batch["data"],
                    batch_target=batch["target"],
                    cv_set=cv_set,
                    model_fn=model_fn,
                    model_weights=model_weights,
                    model_biases=model_biases,
                    learning_mode=LEARNING_MODE,
                    lr=LEARNING_RATE,
                )
                assert model_weights.ndim == 2, model_weights.ndim
                assert model_biases.ndim == 1, model_biases.ndim

                for k, v in metrics.items():
                    all_metrics[k].append(v)
            
            for k, v in all_metrics.items():
                rich.print(f"[bold green]({cv_set})[/] Epoch {epoch}: [bold blue]{k.title()}:[/] {np.mean(v):0.2%}")

train()

            


grad_w
grad_b
Loading digits.
Done loading digits.


  0%|          | 0/144 [00:00<?, ?it/s]


Step:
loss comp: 
	 - vmap_in_axes = [0, 0, None, None] 
	 - batch_data.shape = (10, 64) 
	 - batch_target.shape = (10,) 
	 - cv_set = <CVSets.TRAIN: 0>
batch_target.shape = ()
log_softmax.shape = (10,)
target.shape = (1,)
Grad Desc:


AssertionError: (10, 64)

In [3]:
jnp.ones(shape=(3, 3))

Array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)