In [1]:
# install dependencies
!pip install flax optax

Collecting flax
  Downloading flax-0.4.1-py3-none-any.whl (184 kB)
[?25l[K     |█▉                              | 10 kB 32.2 MB/s eta 0:00:01[K     |███▋                            | 20 kB 38.6 MB/s eta 0:00:01[K     |█████▍                          | 30 kB 40.1 MB/s eta 0:00:01[K     |███████▏                        | 40 kB 27.3 MB/s eta 0:00:01[K     |█████████                       | 51 kB 21.0 MB/s eta 0:00:01[K     |██████████▊                     | 61 kB 24.0 MB/s eta 0:00:01[K     |████████████▌                   | 71 kB 23.9 MB/s eta 0:00:01[K     |██████████████▎                 | 81 kB 25.0 MB/s eta 0:00:01[K     |████████████████                | 92 kB 26.8 MB/s eta 0:00:01[K     |█████████████████▉              | 102 kB 28.7 MB/s eta 0:00:01[K     |███████████████████▋            | 112 kB 28.7 MB/s eta 0:00:01[K     |█████████████████████▍          | 122 kB 28.7 MB/s eta 0:00:01[K     |███████████████████████▏        | 133 kB 28.7 MB/s eta 0:00:01

In [2]:
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
import jax.lax as lax

import optax

import flax
from flax import linen as nn

import torchvision

import numpy as np  # TO DO remove np's -> jnp
import contextlib

from typing import Tuple, Union, List, OrderedDict, Callable
from dataclasses import field

# jaxopt has already implicit differentiation!!
import time

from matplotlib import pyplot as plt


In [3]:
# utility function for local random seeding
@contextlib.contextmanager
def np_temp_seed(seed):
	state = np.random.get_state()
	np.random.seed(seed)
	try:
		yield
	finally:
		np.random.set_state(state)

In [4]:
def _safe_norm_jax(v):
    if not jnp.all(jnp.isfinite(v)):
        return jnp.inf
    return jnp.linalg.norm(v)

def scalar_search_armijo_jax(phi, phi0, derphi0, c1=1e-4, alpha0=1, amin=0):
    ite = 0
    phi_a0 = phi(alpha0)    # First do an update with step size 1
    if phi_a0 <= phi0 + c1*alpha0*derphi0:
        return alpha0, phi_a0, ite

    # Otherwise, compute the minimizer of a quadratic interpolant
    alpha1 = -(derphi0) * alpha0**2 / 2.0 / (phi_a0 - phi0 - derphi0 * alpha0)
    phi_a1 = phi(alpha1)

    # Otherwise loop with cubic interpolation until we find an alpha which
    # satisfies the first Wolfe condition (since we are backtracking, we will
    # assume that the value of alpha is not too small and satisfies the second
    # condition.
    while alpha1 > amin:       # we are assuming alpha>0 is a descent direction
        factor = alpha0**2 * alpha1**2 * (alpha1-alpha0)
        a = alpha0**2 * (phi_a1 - phi0 - derphi0*alpha1) - \
            alpha1**2 * (phi_a0 - phi0 - derphi0*alpha0)
        a = a / factor
        b = -alpha0**3 * (phi_a1 - phi0 - derphi0*alpha1) + \
            alpha1**3 * (phi_a0 - phi0 - derphi0*alpha0)
        b = b / factor

        alpha2 = (-b + jnp.sqrt(jnp.abs(b**2 - 3 * a * derphi0))) / (3.0*a)
        phi_a2 = phi(alpha2)
        ite += 1

        if (phi_a2 <= phi0 + c1*alpha2*derphi0):
            return alpha2, phi_a2, ite

        if (alpha1 - alpha2) > alpha1 / 2.0 or (1 - alpha2/alpha1) < 0.96:
            alpha2 = alpha1 / 2.0

        alpha0 = alpha1
        alpha1 = alpha2
        phi_a0 = phi_a1
        phi_a1 = phi_a2

    # Failed to find a suitable step length
    return None, phi_a1, ite


def line_search_jax(update, x0, g0, g, nstep=0, on=True):
    """
    `update` is the propsoed direction of update.

    Code adapted from scipy.
    """
    tmp_s = [0]
    tmp_g0 = [g0]
    tmp_phi = [jnp.linalg.norm(g0)**2]
    s_norm = jnp.linalg.norm(x0) / jnp.linalg.norm(update)

    def phi(s, store=True):
        if s == tmp_s[0]:
            return tmp_phi[0]    # If the step size is so small... just return something
        x_est = x0 + s * update
        g0_new = g(x_est)
        phi_new = _safe_norm_jax(g0_new)**2
        if store:
            tmp_s[0] = s
            tmp_g0[0] = g0_new
            tmp_phi[0] = phi_new
        return phi_new
    
    if on:
        s, phi1, ite = scalar_search_armijo_jax(phi, tmp_phi[0], -tmp_phi[0], amin=1e-2)
    if (not on) or s is None:
        s = 1.0
        ite = 0

    x_est = x0 + s * update
    if s == tmp_s[0]:
        g0_new = tmp_g0[0]
    else:
        g0_new = g(x_est)
    return x_est, g0_new, x_est - x0, g0_new - g0, ite



In [5]:
def rmatvec_jax(part_Us, part_VTs, x):
    # Compute x^T(-I + UV^T)
    # x: (N, 2d, L')
    # part_Us: (N, 2d, L', threshold)
    # part_VTs: (N, threshold, 2d, L')
    if jnp.size(part_Us) == 0:
        return -x
    xTU = jnp.einsum('bij, bijd -> bd', x, part_Us)   # (N, threshold)
    return -x + jnp.einsum('bd, bdij -> bij', xTU, part_VTs)    # (N, 2d, L'), but should really be (N, 1, (2d*L'))

def matvec_jax(part_Us, part_VTs, x):
    # Compute (-I + UV^T)x
    # x: (N, 2d, L')
    # part_Us: (N, 2d, L', threshold)
    # part_VTs: (N, threshold, 2d, L')
    if jnp.size(part_Us) == 0:
        return -x
    VTx = jnp.einsum('bdij, bij -> bd', part_VTs, x)  # (N, threshold)
    return -x + jnp.einsum('bijd, bd -> bij', part_Us, VTx)     # (N, 2d, L'), but should really be (N, (2d*L'), 1)


In [6]:
def broyden_jax(f, x0, threshold, eps=1e-3, stop_mode="rel", ls=False, name="unknown"):
    bsz, total_hsize, seq_len = x0.shape
    g = lambda y: f(y) - y
    dev = x0.device()
    alternative_mode = 'rel' if stop_mode == 'abs' else 'abs'
    
    x_est = x0           # (bsz, 2d, L')
    gx = g(x_est)        # (bsz, 2d, L')
    nstep = 0
    tnstep = 0
    
    # For fast calculation of inv_jacobian (approximately)
    Us = jax.device_put(jnp.zeros((bsz, total_hsize, seq_len, threshold)),dev)     # One can also use an L-BFGS scheme to further reduce memory
    VTs = jax.device_put(jnp.zeros((bsz, threshold, total_hsize, seq_len)),dev)
    update = -matvec_jax(Us[:,:,:,:nstep], VTs[:,:nstep], gx)      # Formally should be -torch.matmul(inv_jacobian (-I), gx)
    prot_break = False
    
    # To be used in protective breaks
    protect_thres = (1e6 if stop_mode == "abs" else 1e3) * seq_len
    new_objective = 1e8

    trace_dict = {'abs': [],
                  'rel': []}
    lowest_dict = {'abs': 1e8,
                   'rel': 1e8}
    lowest_step_dict = {'abs': 0,
                        'rel': 0}
    nstep, lowest_xest, lowest_gx = 0, x_est, gx

    while nstep < threshold:
        x_est, gx, delta_x, delta_gx, ite = line_search_jax(update, x_est, gx, g, nstep=nstep, on=ls)
        nstep += 1
        tnstep += (ite+1)
        abs_diff = jnp.linalg.norm(gx)
        rel_diff = abs_diff / (jnp.linalg.norm(gx + x_est) + 1e-9)
        diff_dict = {'abs': abs_diff,
                     'rel': rel_diff}
        trace_dict['abs'].append(abs_diff)
        trace_dict['rel'].append(rel_diff)
        for mode in ['rel', 'abs']:
            if diff_dict[mode] < lowest_dict[mode]:
                if mode == stop_mode: 
                    lowest_xest, lowest_gx = lax.stop_gradient(x_est.copy()), lax.stop_gradient(gx.copy())
                lowest_dict[mode] = diff_dict[mode]
                lowest_step_dict[mode] = nstep

        new_objective = diff_dict[stop_mode]
        if new_objective < eps: break
        if new_objective < 3*eps and nstep > 30 and np.max(trace_dict[stop_mode][-30:]) / np.min(trace_dict[stop_mode][-30:]) < 1.3:
            # if there's hardly been any progress in the last 30 steps
            break
        if new_objective > trace_dict[stop_mode][0] * protect_thres:
            prot_break = True
            break

        part_Us, part_VTs = Us[:,:,:,:nstep-1], VTs[:,:nstep-1]
        vT = rmatvec_jax(part_Us, part_VTs, delta_x)
        u = (delta_x - matvec_jax(part_Us, part_VTs, delta_gx)) / jnp.einsum('bij, bij -> b', vT, delta_gx)[:,None,None]
        vT = jnp.nan_to_num(vT,nan=0.)
        u = jnp.nan_to_num(u,nan=0.)
        VTs = VTs.at[:,nstep-1].set(vT)
        Us = Us.at[:,:,:,nstep-1].set(u)
        update = -matvec_jax(Us[:,:,:,:nstep], VTs[:,:nstep], gx)

    # Fill everything up to the threshold length
    for _ in range(threshold+1-len(trace_dict[stop_mode])):
        trace_dict[stop_mode].append(lowest_dict[stop_mode])
        trace_dict[alternative_mode].append(lowest_dict[alternative_mode])

    return {"result": lowest_xest,
            "lowest": lowest_dict[stop_mode],
            "nstep": lowest_step_dict[stop_mode],
            "prot_break": prot_break,
            "abs_trace": trace_dict['abs'],
            "rel_trace": trace_dict['rel'],
            "eps": eps,
            "threshold": threshold}


def newton_jax(f, x0, threshold, eps=1e-3, stop_mode="rel", name="unknown"):

    g = lambda y: f(y) - y
    jac_g = jax.jacfwd(g)
    x = x0
    gx = g(x)
    gx_norm = jnp.linalg.norm(gx)
    nstep = 0
    # print(gx_norm)

    while nstep < threshold:
      # solve system
      delta_x = jnp.linalg.solve(jac_g(x),-g(x))
      x = x + delta_x
      gx = g(x)
      gx_norm = jnp.linalg.norm(gx)
      nstep += 1
      # print(gx_norm)

    return x, gx, gx_norm

In [7]:
class MDEQBlock(nn.Module):
    input: jnp.DeviceArray
    input_dim: int = 8
    hidden_dim: int = 2*input_dim
    kernel_size: Tuple[int] = (3, 3)  # can also be (5, 5), modify later
    num_groups: int = 2
    curr_branch: int = 0
    kernel_init = jax.nn.initializers.glorot_normal()
    bias_init = jax.nn.initializers.glorot_normal()

    
    def setup(self, i, num_channels):  
        self.input_dim = num_channels
        self.hidden_dim =  2*self.input_dim
        self.curr_branch = i

        # init-substitute for flax
        self.conv1 = nn.Conv(features=self.hidden_dim, kernel_size=self.kernel_size,
                             strides=1, kernel_init=self.kernel_init, bias_init=self.bias_init)
        self.group1 = nn.GroupNorm(num_groups=self.num_groups, group_size=self.hidden_dim)
        self.relu = nn.relu()
        self.conv2 = nn.Conv(features=self.input_dim, kernel_size=self.kernel_size,
                             strides=1, kernel_init=self.kernel_init, bias_init=self.bias_init)
        self.group2 = nn.GroupNorm(num_groups=self.num_groups, group_size=self.input_dim)
        self.group3 = nn.GroupNorm(num_groups=self.num_groups, group_size=self.input_dim)


    def __call__(self, x, branch, injection):
        # forward pass
        h1 = self.group1(self.conv1(x))
        h1 = self.relu(h1)
        
        h2 = self.conv2(z)
        if branch == 0:
            h2 += injection
        h2 = self.group2(h2)
        h2 += x
        
        h3 = self.relu(h2)
        out = self.group3(h3)

        return out


    
''' 
    assert statement we'll need    
    assert that the number of branches == len(input_channel_vector)
    assert also that num_branches == len(kernel_size_vector)
'''

" \n    assert statement we'll need    \n    assert that the number of branches == len(input_channel_vector)\n    assert also that num_branches == len(kernel_size_vector)\n"

In [8]:
class DownSample(nn.Module):
    def setup(self, branches, channel_dimensions, num_groups):
        self.in_chan, self.out_chan = channel_dimensions
        self.num_groups = num_groups

    def _downsample(self, branches):
        from_res, to_res = branches  # sampling from resolution from_res to to_res
        num_samples = to_res - from_res
        assert num_samples > 0

        down_block = []

        for n in range(len(num_samples)):
            inter_chan = self.in_chan if n < num_samples-1 else self.out_chan
            conv_down = nn.Conv(features=inter_chan, kernel_size=3, strides=2, padding=1,
                               bias=False)
            group_down = nn.GroupNorm(num_groups=self.num_groups,
                                      group_size=inter_chan)
            relu_down = nn.relu()
            module_list = [conv_down, group_down]
            if n < num_samples - 1:
                module = nn.Sequential([conv_down,
                                        group_down,
                                        relu_down])
            else:
                module = nn.Sequential([conv_down,
                                        group_down])
            down_block.append(module)
        return nn.Sequential(down_block)

    def __call__(self, branches, z_plus):
        downsample = self._downsample(branches)
        return downsample(z_plus)


In [9]:
class UpSample(nn.Module):
    def setup(self, channel_dimensions, num_groups):
        self.in_chan, self.out_chan = channel_dimensions
        self.num_groups = num_groups

    def _upsample(self, branches):
        from_res, to_res = branches  # sampling from resolution from_res to to_res
        num_samples = from_res - to_res
        assert num_samples > 0

        return [nn.Conv(features=self.out_chan, kernel_size=1, bias=False),
                        nn.GroupNorm(num_groups=self.num_groups, group_size=self.out_chan),
                        nn.Upsample(scale_factor=2**num_samples)]

    def __call__(self, branches, z_plus):
        upsample = self._upsample(branches)
        return upsample(z_plus)

In [10]:
class f_theta(nn.Module):
    features: Tuple[int] = (16, 4)
    num_groups: int = 8
    channels: List[int] = field(default_factory=lambda:[24, 24, 24])
    kernel_init = 
    # branches: List[int] = field(default_factory=lambda:[24, 24, 24])

    def setup(self):

        self.res_block = MDEQBlock()
        self.downsample = DownSample(channel_dimensions=self.channels,
                                     num_groups=self.num_groups)
        self.upsample = UpSample(channel_dimensions=self.channels,
                                 num_groups=self.num_groups)

        self.branches = self.stack_branches()
        self.num_branches = len(self.branches)

        self.fuse_branches = self.fuse()
        self.transform = self.transform_output()

    def stack_branches(self):
        branches = []
        for i in self.num_branches:
          branches.append(MDEQBlock(i, self.channels[i]))
        return branches

    def fuse(self, z_plus, channel_dimensions):
        # up- and downsampling stuff
        # z_plus: output of residual block
        if self.num_branches == 1:
            return None
        out = 1
        fuse_layers = []
        for i in range(self.num_branches):
            array = []
            for j in range(self.num_branches):
                if i == j:
                    array.append(z_plus[i])
                else:
                    if i < j:
                        sampled = self.downsample(z_plus=z_plus, branches=(i, j),
                                                 channel_dimension=channel_dimensions)
                    elif i > j:
                        sampled = self.upsample(z_plus=z_plus, branches=(i, j),
                                                 channel_dimension=channel_dimensions)
                    array.append(nn.Module(sampled))
            # fuse_layers.append(nn.Module(array))
            fuse_layers.append(array)

        return fuse_layers
    
    def transform_output(self):
        transforms = []
        for i in range(self.num_branches):
          transforms.append(nn.Sequential([nn.relu(),
                                          nn.Conv(features=self.channels[i], kernel_size=1, bias=False),
                                          nn.GroupNorm(num_groups=self.num_groups//2,
                                                       group_size=self.channels[i])]))
        
        return transforms

    def __call__(self, x, injection):
        # step 1: compute residual blocks
        branch_outputs = []
        for i in range(self.num_branches):
            branch_outputs.append(self.branches[i](x[i], i, injection[i])) # z, branch, x

        # step 2: fuse residual blocks
        fuse_outputs = []
        for i in range(self.num_branches):
          intermediate_i = 0
          for j in range(self.num_branches):
            if i == j:
              intermediate_i += branch_outputs[i]
            else:
              intermediate_i += self.fuse[i][j](branch_outputs[j])
          fuse_outputs.append(self.transform[i](intermediate_i))

        return fuse_outputs


    

In [11]:
class MDEQModel(nn.Module):
    features: Tuple[int] = (16, 4)
    num_branches: int = 3
    num_groups: int = 8
    channels: List[int] = field(default_factory=lambda:[24,24,24])
    branches: List[int] = field(default_factory=lambda:[1,1,1])
    training: bool = True
    solver_fn: Callable = broyden_jax
    kernel_init = jax.nn.initializers.glorot_normal()
    bias_init = jax.nn.initializers.glorot_normal()


    def setup(self):
        self.num_branches = len(self.branches)
        self.transform = [nn.Sequential(OrderedDict([
                                                     (nn.Conv(features=self.channels[i],
                                                              kernel_size=3, stride=1,
                                                              kernel_init=self.kernel_init,
                                                              bias_init=self.bias_init)),
                                                     (nn.BatchNorm()),
                                                     (nn.relu()),
                                                     (nn.Conv(features=self.channels[i], 
                                                              kernel_size=3, stride=1,
                                                              kernel_init=self.kernel_init,
                                                              bias_init=self.bias_init)),
                                                     (nn.BatchNorm()),
                                                     (nn.relu())])) for i in range(self.num_branches)]
        self.model = f_theta()
        
    def __call__(self, x):
        x = self.transform(x)
        x_list = [x]
        for i in range(self.num_branches):
            bs, _, H, W = x_list[-1].shape
            x_list.append(np.zeros(bs, self.channels[i], H//2, W//2))
        z_list = [np.zeros(elem) for elem in x_list]

        bsz = x.shape[0]
        func = lambda z: self.model(z_list, x_list)
        z_vec = jnp.cat([elem.reshape(bsz, -1, 1) for elem in z_list], dim=1)
        result = self.solver(func, z_vec, threshold=0.001)
        z_vec = result['result']
        output = z_vec
        if self.training:
            output = func(z_vec.requires_grad_())
        # jac_loss = jac_loss_estimate(output, z1) # comes from the follow-up paper
        jac_loss = None
        
        y_list = output # TO DO -- for now without dropout!
        return y_list, jac_loss

In [12]:
def transform(image, label, num_classes=10):
    image = jnp.float32(image) / 255.
    label = jax.nn.one_hot(label, num_classes=num_classes)
    return image, label

def load_data():
    test_ds = torchvision.datasets.CIFAR10(root="data", train=False,download=True)
    train_ds = torchvision.datasets.CIFAR10(root="data", train=True,download=True)

    train_images, train_labels = transform(train_ds.data[:1000], train_ds.targets[:1000])
    test_images, test_labels = transform(test_ds.data[:200], test_ds.targets[:200])
    return train_images, train_labels, test_images, test_labels

In [17]:
def train():
    '''
    extra thing: warm-up using gradient descent in pytorch code of official repo
    --> check impact of that and maybe also cost etc (eg if only one layer etc)
    '''

    max_itr = 1000
    print_interval = 100

    data = load_data()
    data_size = data.shape[0]
    batch_size = 100

    solver_fn = broyden_jax

    my_deq = MDEQModel(solver_fn=solver_fn)
    print(my_deq)

    # def cross_entropy_loss(*, logits, labels):
    def cross_entropy_loss(logits, labels):
        ''' 
        should be same as  optax.softmax_cross_entropy(logits, labels); 
        if getting funny results maybe remove log of logits
        '''
        one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
        return -jnp.mean(jnp.sum(one_hot_labels * jnp.log(logits), axis=-1))

    png = jax.random.PRNGKey(0)
    png, key = jax.random.split(png, 2)
    dummy_input = jnp.ones(shape=(batch_size, 32, 32, 3))
    #init_weights = init(key, (shape of weights), jnp.float64)
    # x0 = random.glorot_normal(key, shape=(1,))

    weights = my_deq.init(png, dummy_input)
    optimizer = optax.adamw(learning_rate=0.001, weight_decay=0.001)
    opt_state = optimizer.init(weights)

    loss = cross_entropy_loss

    def step(weights, opt_state, x_batch, y_true):
        loss_vals, grad = jax.value_and_grad(loss, has_aux=True)(weights, x_batch, y_true)
        updates, opt_state = optimizer.update(grad, opt_state, weights)
        weights = optax.apply_updates(weights, updates)
        return weights, opt_state, loss_vals



    def generator(batch_size: int=10):
        ''' https://optax.readthedocs.io/en/latest/meta_learning.html?highlight=generator#meta-learning '''
        rng = jax.random.PRNGKey(0)

        while True:
            rng, k1 = jax.random.split(rng, num=2)
            idxs = jax.random.uniform(k1, shape=(batch_size), minval=0, maxval=data_size, dtype=jnp.int32)
            yield idxs

    g = generator(batch_size=batch_size)

    for itr in range(max_itr):
        batch_idxs = next(g)
        x_batch = data[batch_idxs]
        params, opt_state, loss_vals = step(params, opt_state, x_batch)
        
        if itr % print_interval == 0:
            print("\tat step", itr, "have loss", loss_vals)

        if loss_vals < 1e-5:
            break
    

In [18]:
train()

Files already downloaded and verified
Files already downloaded and verified


TypeError: ignored

Breakdown of code overall:


*   MDEQ modul
*   List item

