In [None]:
!git clone -b without_rng_and_min_damping https://github.com/SheaCardozo/CSC2541_Project/ 

fatal: destination path 'CSC2541_Project' already exists and is not an empty directory.


In [None]:
!pip install git+https://github.com/deepmind/dm-haiku &> /dev/null
!pip uninstall -y kfac_jax &> /dev/null
!pip install git+https://github.com/SheaCardozo/kfac_jax &> /dev/null
!pip install optax &> /dev/null

from jax.config import config
config.update("jax_enable_x64", True)
#config.update("jax_log_compiles", True)

import warnings
warnings.filterwarnings("ignore")

import haiku as hk
from haiku.nets import ResNet18

import jax
from jax import numpy as jnp
from jax.tree_util import tree_reduce

import json
import kfac_jax
import optax
import os
import numpy as onp
import pickle
import tensorflow_datasets as tfds
import time

from functools import partial
from itertools import product
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score
from tqdm import trange

from CSC2541_Project.hf.optimizer import hf

In [None]:
# Show the attached GPU; the setup cell below calls jax.devices("gpu") and fails without one.
!nvidia-smi

Tue Apr 12 22:49:31 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Model and Utils

In [None]:
class DCNet(hk.Module):
    """Convolutional classifier that returns 10 logits per example.

    Four downsampling stages, each a 4x4 stride-2 SAME conv without bias followed
    by BatchNorm and a leaky ReLU (slope 0.2), widen the channels
    64 -> 128 -> 256 -> 512. A 2x2 VALID conv then maps to 1024 channels, and a
    linear layer produces the logits. For the 32x32 inputs used in this notebook,
    the spatial size shrinks 32 -> 16 -> 8 -> 4 -> 2 -> 1 before flattening.
    """

    def __init__(self):
        super().__init__(name="DCNet")

        def downsample_conv(width):
            return hk.Conv2D(output_channels=width, kernel_shape=4, stride=2, padding="SAME", with_bias=False)

        def batch_norm():
            return hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9, data_format="N...C")

        # Haiku derives parameter names from creation order, so keep this order
        # stable (conv1..conv5 are the 1st..5th Conv2D) to keep checkpoints loadable.
        self.conv1 = downsample_conv(64)
        self.bn1 = batch_norm()
        self.conv2 = downsample_conv(128)
        self.bn2 = batch_norm()
        self.conv3 = downsample_conv(256)
        self.bn3 = batch_norm()
        self.conv4 = downsample_conv(512)
        self.bn4 = batch_norm()

        self.conv5 = hk.Conv2D(output_channels=1024, kernel_shape=2, stride=1, padding="valid")
        self.fc = hk.Linear(10, with_bias=True)

    def __call__(self, x, training):
        """Map a batch of NHWC images to logits; `training` selects BatchNorm mode."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, norm in stages:
            x = jax.nn.leaky_relu(norm(conv(x), is_training=training), negative_slope=0.2)

        features = self.conv5(x)
        flat = jnp.reshape(features, [features.shape[0], -1])
        return self.fc(flat)

def unpickle(file):
    """Load and return the object pickled in `file`.

    `encoding='bytes'` decodes Python 2 `str` payloads as `bytes`.
    Only call this on trusted files: unpickling can execute arbitrary code.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

def get_datasets(dataset):
    cpus = jax.devices("cpu")

    if dataset == "mnist":
        ds_builder = tfds.builder('mnist')
        ds_builder.download_and_prepare()
        train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
        test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
        train_ds['image'] = (((jax.device_put(jnp.array(train_ds['image']), device=cpus[0]) / 255) - 0.5) / 0.5)
        test_ds['image'] = (((jax.device_put(jnp.array(test_ds['image']), device=cpus[0]) / 255) - 0.5) / 0.5)

        train_ds['image'] = jax.image.resize(train_ds['image'], (train_ds['image'].shape[0], 32, 32, 1), method='cubic').astype(jnp.float64)
        test_ds['image'] = jax.image.resize(test_ds['image'], (test_ds['image'].shape[0], 32, 32, 1), method='cubic').astype(jnp.float64)

        return train_ds, test_ds
    elif dataset == 'cifar10':        
        ds_builder = tfds.image_classification.Cifar10()
        ds_builder.download_and_prepare()
        train_ds = tfds.as_numpy(
            ds_builder.as_dataset(split='train', batch_size=-1))
        test_ds = tfds.as_numpy(
            ds_builder.as_dataset(split='test', batch_size=-1))
        
        train_ds['image'] = (((jax.device_put(jnp.array(train_ds['image']), device=cpus[0]) / 255) - 0.5) / 0.5)
        test_ds['image'] = (((jax.device_put(jnp.array(test_ds['image']), device=cpus[0]) / 255) - 0.5) / 0.5)

        train_ds['image'] = train_ds['image'].astype(jnp.float64)
        test_ds['image'] = test_ds['image'].astype(jnp.float64)

        return train_ds, test_ds

    else:
        raise ValueError("Not Implemented")

def setup_log (log_loss, optim_dir_name, optim_mag_name):
    """Register an empty metrics record for one magnitude#direction grafting run.

    Adds `log_loss['<mag>#<dir>']` as a dict mapping each metric name to its own
    empty list. Any existing record under that key is replaced.

    Returns:
        The same `log_loss` dict, mutated in place.
    """
    run_key = f'{optim_mag_name}#{optim_dir_name}'
    metric_names = ("losses", "val_metric", "mag_m", "mag_d", "mag_inflate", "cos_sim")
    log_loss[run_key] = {metric: [] for metric in metric_names}
    return log_loss

# Config

In [None]:
# --- What to train ---
model = "dcnet"    # "resnet18" or "dcnet"
dataset = "mnist"  # "cifar10" or "mnist"

# --- Schedule ---
iterations = 500   # number of batch iterations (the training loop runs iterations + 1 steps)
lr = 1e-3          # learning rate for the first-order (optax) optimizers
batch_size = 1024  # batch size for training and for the validation subset
checkpoint = 500   # save a checkpoint every this many iterations

# --- Adaptive damping for kfac and hf: (initial, minimum) ---
init_damp_kfac, min_damp_kfac = 10.0, 1.0
init_damp_hf, min_damp_hf = 10.0, 1.0

# L2 regularization coefficient used by kfac_loss_fn and passed to kfac as l2_reg
L2_REG = 0

seed = 451  # random seed

# True: graft layer-wise (per parameter leaf); False: graft using global norms.
layer_wise = True

In [None]:
# Set up models and training data
# You will get an error here if a gpu is not attached
gpus = jax.devices("gpu")
cpus = jax.devices("cpu")

# Full train/test splits, kept on the CPU; training batches are copied to the GPU per step.
train, test = get_datasets(dataset)

# A single split yields the carry key plus dedicated keys for the dummy init input and parameter init.
key = jax.random.PRNGKey(seed=seed)
key, key_samp, key_init = jax.random.split(key, 3)

# transform_with_state because the networks carry BatchNorm state; without_apply_rng
# so that `classifier.apply` takes no rng key.
if model == "dcnet":
    classifier = hk.without_apply_rng(hk.transform_with_state(lambda x, training: DCNet()(x, training)))
elif model == "resnet18":
    classifier = hk.without_apply_rng(hk.transform_with_state(lambda x, training: ResNet18(num_classes=10)(x, training)))
else:
    raise ValueError("Not Implemented")

# Random dummy batch of 9 images with the dataset's 32x32xC shape, used only to initialise params/state.
if dataset == "mnist":
    sample = jax.random.normal(key_samp, shape=(9, 32, 32, 1))
elif dataset == 'cifar10':
    sample = jax.random.normal(key_samp, shape=(9, 32, 32, 3))
else:
    raise ValueError("Not Implemented")

params_base, state_base = classifier.init(key_init, sample, training=True) 
# Keep the base copy on the CPU; each training run copies it onto the GPU.
params_base, state_base = jax.device_put(params_base, device=cpus[0]), jax.device_put(state_base, device=cpus[0])

test_img = test['image']
test_y = test['label']

[1mDownloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...[0m


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.



Dl Completed...:   0%|          | 0/4 [00:00<?, ? file/s]


[1mDataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.[0m
Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.


In [None]:
# Some important functions for the training loop
@jax.jit
def predict(params, state, images):
    """Return the predicted class index for each image (network run with training=False)."""
    # apply returns (logits, new_state); only the logits are needed here.
    logits = classifier.apply(params, state, x=images, training=False)[0]
    return logits.argmax(axis=1)

def softmax_cross_entropy(logits: jnp.ndarray, targets: jnp.ndarray):
    """Per-example softmax cross-entropy, also registered as the loss with kfac_jax.

    The registration call tells kfac_jax that (logits, targets) form a softmax
    cross-entropy loss. The returned value is optax's per-example loss; callers
    average it. In this notebook `targets` are one-hot label vectors.
    """
    kfac_jax.register_softmax_cross_entropy_loss(logits, targets)
    per_example_loss = optax.softmax_cross_entropy(logits, targets)
    return per_example_loss

def kfac_loss_fn(params, state, batch):
    """Training loss in the form kfac_jax is configured for (see kfac_args).

    Args:
        params: haiku parameters.
        state: haiku state (BatchNorm statistics).
        batch: `(images, one_hot_labels)` pair.

    Returns:
        `(mean cross-entropy + L2_REG/2 * ||params||^2, new_state)`. Returning the
        state matches `value_func_has_state=True`.
    """
    images, one_hot_labels = batch
    logits, new_state = classifier.apply(params, state, x=images, training=True)
    data_loss = jnp.mean(softmax_cross_entropy(logits, one_hot_labels))
    l2_penalty = L2_REG * kfac_jax.utils.inner_product(params, params) / 2.0
    return data_loss + l2_penalty, new_state

@jax.jit
def loss_fn(params, state, batch, labs):
    """Mean softmax cross-entropy of the network on a batch, in training mode.

    Args:
        params: haiku parameters.
        state: haiku state.
        batch: input images.
        labs: one-hot labels with the same leading dimension as `batch`.

    Returns:
        `(loss, new_state)`.
    """
    logits, state = classifier.apply(params, state, x=batch, training=True)
    # log_softmax instead of log(softmax(.)): a probability that underflows to 0
    # gives log(0) * 0 = nan in the sum (and in the gradient); log_softmax stays finite.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(log_probs * labs, axis=1)), state

# ((loss, new_state), grads w.r.t. params), compiled as a whole as the name implies.
jit_grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))

def get_grafted_train_step(optim_dir, optim_mag):
    """Build a jitted train step that grafts one optimizer's step size onto another's direction.

    The step takes its *direction* from `optim_dir` and rescales it by
    ||magnitude update|| / ||direction update||. The ratio is computed per
    parameter leaf when the global `layer_wise` is True (read at trace time),
    otherwise once over the whole parameter tree.

    Args:
        optim_dir: optax GradientTransformation supplying the direction. Unused
            when the direction step is passed in precomputed as `dir_kfac`.
        optim_mag: optax GradientTransformation supplying the magnitude. Unused
            when the magnitude step is passed in precomputed as `mag_kfac`.

    Returns:
        `grafted_train_step(batch, labs, params, state, optim_state_dir,
        optim_state_mag, dir_kfac=None, mag_kfac=None, eps=1e-8)`. It returns
        `(loss, params, state, optim_state_dir, optim_state_mag, per-leaf norms of
        the magnitude update, per-leaf norms of the direction update, per-leaf
        cosine similarity between the two updates)`.
    """
    def _leaf_norms(tree):
        # Norm of each leaf over all of its entries (jnp.linalg.norm with default
        # ord/axis treats the leaf as flattened).
        return jax.tree_util.tree_map(jnp.linalg.norm, tree)

    def _cosine(x, y, eps):
        # Cosine similarity of two same-shaped leaves, treated as flat vectors.
        x, y = x.reshape((-1)), y.reshape((-1))
        return jnp.dot(x, y) / (eps + jnp.linalg.norm(x) * jnp.linalg.norm(y))

    @jax.jit
    def grafted_train_step(batch, labs, params, state, optim_state_dir, optim_state_mag, dir_kfac=None, mag_kfac=None, eps=1e-8):
        # Precomputed steps (kfac / hf) arrive as (update, optimizer_state) pairs in
        # dir_kfac / mag_kfac. Only differentiate if an optax optimizer still needs a gradient.
        if dir_kfac is None or mag_kfac is None:
            grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
            (loss, state), grad = grad_fn(params, state, batch, labs)
        else:
            loss, state = loss_fn(params, state, batch, labs)

        if dir_kfac is not None:
            updates_dir, optim_state_dir = dir_kfac
        elif optim_state_dir is not None:
            updates_dir, optim_state_dir = optim_dir.update(grad, optim_state_dir, params)

        if mag_kfac is not None:
            updates_mag, optim_state_mag = mag_kfac
            if optim_state_dir is None:
                updates_dir = updates_mag
        else:
            updates_mag, optim_state_mag = optim_mag.update(grad, optim_state_mag, params)

        norms_mag = _leaf_norms(updates_mag)
        norms_dir = _leaf_norms(updates_dir)

        # jax.tree_util.tree_map accepts several trees; it replaces jax.tree_multimap,
        # which has been removed from JAX.
        if layer_wise:
            scale = jax.tree_util.tree_map(lambda m, d: m / (d + eps), norms_mag, norms_dir)
        else:
            # optax.global_norm is the L2 norm of all leaves taken together, computed
            # without materialising a concatenated copy of the parameters.
            global_scale = optax.global_norm(updates_mag) / (optax.global_norm(updates_dir) + eps)
            scale = jax.tree_util.tree_map(lambda _: global_scale, updates_mag)

        cos_sim = jax.tree_util.tree_map(lambda x, y: _cosine(x, y, eps), updates_mag, updates_dir)

        updates = jax.tree_util.tree_map(lambda s, d: s * d, scale, updates_dir)
        params = optax.apply_updates(params, updates)

        return jnp.mean(loss), params, state, optim_state_dir, optim_state_mag, \
            norms_mag, norms_dir, cos_sim

    return grafted_train_step

In [None]:
# Keyword arguments for the kfac_jax optimizer and for hf's init, collected here
# so they can be modified in one place.
kfac_args = dict(
    value_and_grad_func=jax.value_and_grad(kfac_loss_fn, has_aux=True),
    l2_reg=L2_REG,  # same coefficient as the penalty added inside kfac_loss_fn
    value_func_has_aux=False,
    value_func_has_state=True,  # kfac_loss_fn returns (loss, haiku state)
    value_func_has_rng=False,
    use_adaptive_learning_rate=True,
    use_adaptive_momentum=True,
    use_adaptive_damping=True,
    initial_damping=init_damp_kfac,
    min_damping=min_damp_kfac,
    multi_device=False,
    inverse_update_period=1,
    damping_adaptation_interval=1,
)

hf_args = dict(
    precond="uncentered",
    lambd=init_damp_hf,
    min_damp=min_damp_hf,
    use_momentum=False,
    line_search=True,
)

# Optimization Loop

In [None]:
# To run, place the names and optax optimizer objects in the lists below.
# The code expects the names 'hf' and 'kfac' for those optimizers respectively;
# other names are arbitrary. Leave the optimizer objects for 'hf'/'kfac' as None:
# they are constructed inside `loop` because they do not follow the optax API.
mags_optims_name = ['hf', 'hf', 'kfac', 'kfac']
mags_optims = [None, None, None, None]
dirs_optims_name = ['sgd', 'adam', 'sgd', 'adam']
dirs_optims = [optax.sgd(lr, momentum=0.9), optax.adam(lr), optax.sgd(lr, momentum=0.9), optax.adam(lr)]

# Each run writes checkpoints and logs to <GRAFT_DIR>/<mag>#<dir>/ (Google Drive in Colab).
GRAFT_DIR = "/content/drive/MyDrive/graft_test"


def sample_train_batch(key_choice):
    """Draw `batch_size` distinct training examples and place them on the GPU.

    Returns:
        (images, labels one-hot encoded over 10 classes).
    """
    batch_inds = jnp.array(jax.random.choice(key=key_choice, a=train['image'].shape[0], shape=(batch_size,), replace=False)).astype(jnp.int64)
    batch = jax.device_put(train['image'][batch_inds], device=gpus[0])
    labs = jax.nn.one_hot(jax.device_put(train['label'][batch_inds], device=gpus[0]).astype(jnp.int64), 10)
    return batch, labs


def stacked_leaves(tree):
    """Stack the (scalar) leaves of a pytree into one array for logging."""
    return jnp.vstack(jax.tree_util.tree_flatten(tree)[0])


for optim_dir, optim_dir_name, optim_mag, optim_mag_name in zip(dirs_optims, dirs_optims_name, mags_optims, mags_optims_name):
    run_dir = f"{GRAFT_DIR}/{optim_mag_name}#{optim_dir_name}"
    # Skip creation if the directory already exists so re-running doesn't raise
    # FileExistsError. The parent is deliberately not created, so an unmounted
    # Drive still fails loudly instead of silently writing to local disk.
    if not os.path.isdir(run_dir):
        os.mkdir(run_dir)
    log_loss = {}

    start = time.time()

    grafted_train_step = get_grafted_train_step(optim_dir, optim_mag)

    # The train loop lives in its own function. This was meant to force memory
    # used in one run to be released for the next, but it matters less in Colab.
    def loop(optim_dir, optim_dir_name, optim_mag, optim_mag_name, log_loss, key):
        """Train one magnitude#direction grafting run; return (log_loss, key)."""
        run_name = f'{optim_mag_name}#{optim_dir_name}'
        run_dir = f"{GRAFT_DIR}/{run_name}"

        params, state = jax.device_put(params_base, device=gpus[0]), jax.device_put(state_base, device=gpus[0])

        log_loss = setup_log(log_loss, optim_dir_name, optim_mag_name)
        run_log = log_loss[run_name]  # same dict object: appends below land in log_loss

        # kfac and hf do not conform to the optax API, so they are built here.
        # When one of them fills both roles, a single instance acts as the
        # magnitude optimizer and optim_dir is None.
        if optim_mag_name == "kfac":
            optim_mag = kfac_jax.Optimizer(**kfac_args)
            if optim_dir_name == "kfac":
                optim_dir = None
        elif optim_dir_name == "kfac":
            optim_dir = kfac_jax.Optimizer(**kfac_args)

        if optim_mag_name == "hf":
            optim_mag = hf(classifier, loss_fn)
            if optim_dir_name == "hf":
                optim_dir = None
        elif optim_dir_name == "hf":
            optim_dir = hf(classifier, loss_fn)

        # Optimizer states; kfac's init also needs a sample batch.
        if optim_mag_name == "kfac":
            key, key_choice = jax.random.split(key)
            batch, labs = sample_train_batch(key_choice)
            optim_state_mag = optim_mag.init(params, key, (batch, labs), func_state=state)
        elif optim_mag_name == "hf":
            optim_state_mag = optim_mag.init(params, **hf_args)
        else:
            optim_state_mag = optim_mag.init(params)

        if optim_dir_name == "kfac":
            if optim_mag_name == "kfac":
                optim_state_dir = None
            else:
                key, key_choice = jax.random.split(key)
                batch, labs = sample_train_batch(key_choice)
                optim_state_dir = optim_dir.init(params, key, (batch, labs), func_state=state)
        elif optim_dir_name == "hf":
            if optim_mag_name == "hf":
                optim_state_dir = None
            else:
                optim_state_dir = optim_dir.init(params, **hf_args)
        else:
            optim_state_dir = optim_dir.init(params)

        # Train loop
        for i in trange(iterations + 1):
            key, key_choice = jax.random.split(key)
            batch, labs = sample_train_batch(key_choice)

            # kfac / hf steps are computed here and handed to the jitted train step
            # as precomputed (update, optimizer_state) pairs.
            dir_kfac, mag_kfac = None, None

            # NOTE(review): `key` is given to the kfac step and is also the parent
            # key split at the start of the next iteration.
            if optim_mag_name == "kfac":
                _, mag_update, state, optim_state_mag, _ = \
                    optim_mag.step(params, optim_state_mag, key, batch=(batch, labs), global_step_int=i, func_state=state)
                mag_kfac = mag_update, optim_state_mag
                if optim_dir_name == "kfac":
                    dir_kfac = mag_update, optim_state_mag
            elif optim_dir_name == "kfac":
                _, dir_update, state, optim_state_dir, _ = \
                    optim_dir.step(params, optim_state_dir, key, batch=(batch, labs), global_step_int=i, func_state=state)
                dir_kfac = dir_update, optim_state_dir

            if optim_mag_name == "hf" or optim_dir_name == "hf":
                (_, state), hf_grad = jit_grad_fn(params, state, batch, labs)

                if optim_mag_name == "hf":
                    mag_update, optim_state_mag = optim_mag.update(hf_grad, optim_state_mag, params, state, batch, labs)
                    mag_kfac = mag_update, optim_state_mag
                    if optim_dir_name == "hf":
                        dir_kfac = mag_update, optim_state_mag
                elif optim_dir_name == "hf":
                    dir_update, optim_state_dir = optim_dir.update(hf_grad, optim_state_dir, params, state, batch, labs)
                    dir_kfac = dir_update, optim_state_dir

            # NOTE(review): when kfac/hf ran above, `state` was already advanced by a
            # training-mode forward pass on this batch. The train step runs another
            # one, so BatchNorm statistics are updated twice per iteration.
            loss, params, state, optim_state_dir, optim_state_mag, mag_m, mag_d, cos_sim = \
                grafted_train_step(batch, labs, params, state, optim_state_dir, optim_state_mag, dir_kfac, mag_kfac)

            # Log iteration metrics
            stacked_m, stacked_d = stacked_leaves(mag_m), stacked_leaves(mag_d)
            run_log["losses"].append(loss.item())
            run_log["mag_m"].append(jnp.ravel(stacked_m).tolist())
            run_log["mag_d"].append(jnp.ravel(stacked_d).tolist())
            run_log["mag_inflate"].append(jnp.ravel(stacked_m / (1e-8 + stacked_d)).tolist())
            run_log["cos_sim"].append(jnp.ravel(stacked_leaves(cos_sim)).tolist())

            # Validation accuracy on a random test subset. The key is derived from,
            # but distinct from, the batch key, so the test indices are not drawn
            # with the same key that picked the training batch.
            key_val = jax.random.fold_in(key_choice, 1)
            test_inds = jnp.array(jax.random.choice(key=key_val, a=test_img.shape[0], shape=(batch_size,), replace=False)).astype(jnp.int64)
            preds = predict(params, state, jax.device_put(test_img[test_inds], device=gpus[0]))
            # sklearn converts its inputs to numpy, so the labels need no GPU round trip.
            run_log["val_metric"].append(accuracy_score(test_y[test_inds], preds))

            # Save checkpoint if necessary
            if i % checkpoint == 0:
                to_pickle = (
                    ("params", params),
                    ("state", state),
                    ("optim_state_dir", optim_state_dir),
                    ("optim_state_mag", optim_state_mag),
                )
                for prefix, obj in to_pickle:
                    with open(f"{run_dir}/{prefix}_{run_name}_{i}.pkl", 'wb') as f:
                        pickle.dump(obj, f)

                with open(f"{run_dir}/log_loss_{run_name}_{i}.txt", 'w') as f:
                    f.write(json.dumps(log_loss))

        return log_loss, key

    # Run train loop
    log_loss, key = loop(optim_dir, optim_dir_name, optim_mag, optim_mag_name, log_loss, key)

    end = time.time()

    print(f"{optim_mag_name}#{optim_dir_name}, Time: {end - start}")

  1%|          | 6/501 [07:22<9:08:30, 66.49s/it]