In [620]:
# install dependencies
!pip install flax optax



In [621]:
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
import jax.lax as lax
from jax import random

import optax

import flax
from flax import linen as nn

import torchvision

import numpy as np  # TO DO remove np's -> jnp
import contextlib

from typing import Tuple, Union, List, OrderedDict, Callable
from dataclasses import field

# jaxopt has already implicit differentiation!!
import time

from matplotlib import pyplot as plt


In [622]:
# utility function for local random seeding
@contextlib.contextmanager
def np_temp_seed(seed):
	state = np.random.get_state()
	np.random.seed(seed)
	try:
		yield
	finally:
		np.random.set_state(state)

In [623]:
def _safe_norm_jax(v):
    if not jnp.all(jnp.isfinite(v)):
        return jnp.inf
    return jnp.linalg.norm(v)

def scalar_search_armijo_jax(phi, phi0, derphi0, c1=1e-4, alpha0=1, amin=0):
    ite = 0
    phi_a0 = phi(alpha0)    # First do an update with step size 1
    if phi_a0 <= phi0 + c1*alpha0*derphi0:
        return alpha0, phi_a0, ite

    # Otherwise, compute the minimizer of a quadratic interpolant
    alpha1 = -(derphi0) * alpha0**2 / 2.0 / (phi_a0 - phi0 - derphi0 * alpha0)
    phi_a1 = phi(alpha1)

    # Otherwise loop with cubic interpolation until we find an alpha which
    # satisfies the first Wolfe condition (since we are backtracking, we will
    # assume that the value of alpha is not too small and satisfies the second
    # condition.
    while alpha1 > amin:       # we are assuming alpha>0 is a descent direction
        factor = alpha0**2 * alpha1**2 * (alpha1-alpha0)
        a = alpha0**2 * (phi_a1 - phi0 - derphi0*alpha1) - \
            alpha1**2 * (phi_a0 - phi0 - derphi0*alpha0)
        a = a / factor
        b = -alpha0**3 * (phi_a1 - phi0 - derphi0*alpha1) + \
            alpha1**3 * (phi_a0 - phi0 - derphi0*alpha0)
        b = b / factor

        alpha2 = (-b + jnp.sqrt(jnp.abs(b**2 - 3 * a * derphi0))) / (3.0*a)
        phi_a2 = phi(alpha2)
        ite += 1

        if (phi_a2 <= phi0 + c1*alpha2*derphi0):
            return alpha2, phi_a2, ite

        if (alpha1 - alpha2) > alpha1 / 2.0 or (1 - alpha2/alpha1) < 0.96:
            alpha2 = alpha1 / 2.0

        alpha0 = alpha1
        alpha1 = alpha2
        phi_a0 = phi_a1
        phi_a1 = phi_a2

    # Failed to find a suitable step length
    return None, phi_a1, ite


def line_search_jax(update, x0, g0, g, nstep=0, on=True):
    """
    `update` is the propsoed direction of update.

    Code adapted from scipy.
    """
    tmp_s = [0]
    tmp_g0 = [g0]
    tmp_phi = [jnp.linalg.norm(g0)**2]
    s_norm = jnp.linalg.norm(x0) / jnp.linalg.norm(update)

    def phi(s, store=True):
        if s == tmp_s[0]:
            return tmp_phi[0]    # If the step size is so small... just return something
        x_est = x0 + s * update
        g0_new = g(x_est)
        phi_new = _safe_norm_jax(g0_new)**2
        if store:
            tmp_s[0] = s
            tmp_g0[0] = g0_new
            tmp_phi[0] = phi_new
        return phi_new
    
    if on:
        s, phi1, ite = scalar_search_armijo_jax(phi, tmp_phi[0], -tmp_phi[0], amin=1e-2)
    if (not on) or s is None:
        s = 1.0
        ite = 0

    x_est = x0 + s * update
    if s == tmp_s[0]:
        g0_new = tmp_g0[0]
    else:
        g0_new = g(x_est)
    return x_est, g0_new, x_est - x0, g0_new - g0, ite



In [624]:
def rmatvec_jax(part_Us, part_VTs, x):
    # Compute x^T(-I + UV^T)
    # x: (N, 2d, L')
    # part_Us: (N, 2d, L', threshold)
    # part_VTs: (N, threshold, 2d, L')
    if jnp.size(part_Us) == 0:
        return -x
    xTU = jnp.einsum('bij, bijd -> bd', x, part_Us)   # (N, threshold)
    return -x + jnp.einsum('bd, bdij -> bij', xTU, part_VTs)    # (N, 2d, L'), but should really be (N, 1, (2d*L'))

def matvec_jax(part_Us, part_VTs, x):
    # Compute (-I + UV^T)x
    # x: (N, 2d, L')
    # part_Us: (N, 2d, L', threshold)
    # part_VTs: (N, threshold, 2d, L')
    if jnp.size(part_Us) == 0:
        return -x
    VTx = jnp.einsum('bdij, bij -> bd', part_VTs, x)  # (N, threshold)
    return -x + jnp.einsum('bijd, bd -> bij', part_Us, VTx)     # (N, 2d, L'), but should really be (N, (2d*L'), 1)


In [625]:
def broyden_jax(f, x0, threshold, eps=1e-3, stop_mode="rel", ls=False, name="unknown"):
    bsz, total_hsize, seq_len = x0.shape
    g = lambda y: f(y) - y
    dev = x0.device()
    alternative_mode = 'rel' if stop_mode == 'abs' else 'abs'
    
    x_est = x0           # (bsz, 2d, L')
    gx = g(x_est)        # (bsz, 2d, L')
    nstep = 0
    tnstep = 0
    
    # For fast calculation of inv_jacobian (approximately)
    Us = jax.device_put(jnp.zeros((bsz, total_hsize, seq_len, threshold)),dev)     # One can also use an L-BFGS scheme to further reduce memory
    VTs = jax.device_put(jnp.zeros((bsz, threshold, total_hsize, seq_len)),dev)
    update = -matvec_jax(Us[:,:,:,:nstep], VTs[:,:nstep], gx)      # Formally should be -torch.matmul(inv_jacobian (-I), gx)
    prot_break = False
    
    # To be used in protective breaks
    protect_thres = (1e6 if stop_mode == "abs" else 1e3) * seq_len
    new_objective = 1e8

    trace_dict = {'abs': [],
                  'rel': []}
    lowest_dict = {'abs': 1e8,
                   'rel': 1e8}
    lowest_step_dict = {'abs': 0,
                        'rel': 0}
    nstep, lowest_xest, lowest_gx = 0, x_est, gx

    while nstep < threshold:
        x_est, gx, delta_x, delta_gx, ite = line_search_jax(update, x_est, gx, g, nstep=nstep, on=ls)
        nstep += 1
        tnstep += (ite+1)
        abs_diff = jnp.linalg.norm(gx)
        rel_diff = abs_diff / (jnp.linalg.norm(gx + x_est) + 1e-9)
        diff_dict = {'abs': abs_diff,
                     'rel': rel_diff}
        trace_dict['abs'].append(abs_diff)
        trace_dict['rel'].append(rel_diff)
        for mode in ['rel', 'abs']:
            if diff_dict[mode] < lowest_dict[mode]:
                if mode == stop_mode: 
                    # lowest_xest, lowest_gx = lax.stop_gradient(x_est.copy()), lax.stop_gradient(gx.copy())
                    lowest_xest, lowest_gx = lax.stop_gradient(x_est), lax.stop_gradient(gx)
                lowest_dict[mode] = diff_dict[mode]
                lowest_step_dict[mode] = nstep

        new_objective = diff_dict[stop_mode]
        if new_objective < eps: break
        if new_objective < 3*eps and nstep > 30 and np.max(trace_dict[stop_mode][-30:]) / np.min(trace_dict[stop_mode][-30:]) < 1.3:
            # if there's hardly been any progress in the last 30 steps
            break
        if new_objective > trace_dict[stop_mode][0] * protect_thres:
            prot_break = True
            break

        part_Us, part_VTs = Us[:,:,:,:nstep-1], VTs[:,:nstep-1]
        vT = rmatvec_jax(part_Us, part_VTs, delta_x)
        u = (delta_x - matvec_jax(part_Us, part_VTs, delta_gx)) / jnp.einsum('bij, bij -> b', vT, delta_gx)[:,None,None]
        vT = jnp.nan_to_num(vT,nan=0.)
        u = jnp.nan_to_num(u,nan=0.)
        VTs = VTs.at[:,nstep-1].set(vT)
        Us = Us.at[:,:,:,nstep-1].set(u)
        update = -matvec_jax(Us[:,:,:,:nstep], VTs[:,:nstep], gx)

    # Fill everything up to the threshold length
    for _ in range(threshold+1-len(trace_dict[stop_mode])):
        trace_dict[stop_mode].append(lowest_dict[stop_mode])
        trace_dict[alternative_mode].append(lowest_dict[alternative_mode])

    return {"result": lowest_xest,
            "lowest": lowest_dict[stop_mode],
            "nstep": lowest_step_dict[stop_mode],
            "prot_break": prot_break,
            "abs_trace": trace_dict['abs'],
            "rel_trace": trace_dict['rel'],
            "eps": eps,
            "threshold": threshold}


def newton_jax(f, x0, threshold, eps=1e-3, stop_mode="rel", name="unknown"):

    g = lambda y: f(y) - y
    jac_g = jax.jacfwd(g)
    x = x0
    gx = g(x)
    gx_norm = jnp.linalg.norm(gx)
    nstep = 0
    # print(gx_norm)

    while nstep < threshold:
      # solve system
      delta_x = jnp.linalg.solve(jac_g(x),-g(x))
      x = x + delta_x
      gx = g(x)
      gx_norm = jnp.linalg.norm(gx)
      nstep += 1
      # print(gx_norm)

    return x, gx, gx_norm

In [626]:
class MDEQBlock(nn.Module):
    curr_branch: int
    channels: List[int]
    kernel_size: Tuple[int] = (3, 3)  # can also be (5, 5), modify later
    num_groups: int = 2
    kernel_init = jax.nn.initializers.glorot_normal()
    bias_init = jax.nn.initializers.glorot_normal()

    
    def setup(self):  
        self.input_dim = self.channels[self.curr_branch]
        self.hidden_dim = 2*self.input_dim

        # init-substitute for flax
        self.conv1 = nn.Conv(features=self.hidden_dim, kernel_size=self.kernel_size,
                             strides=1)#, kernel_init=self.kernel_init, bias_init=self.bias_init)
        self.group1 = nn.GroupNorm(num_groups=self.num_groups)
        self.relu = nn.relu
        self.conv2 = nn.Conv(features=self.input_dim, kernel_size=self.kernel_size,
                             strides=1)#, kernel_init=self.kernel_init, bias_init=self.bias_init)
        self.group2 = nn.GroupNorm(num_groups=self.num_groups)
        self.group3 = nn.GroupNorm(num_groups=self.num_groups)


    def __call__(self, x, branch, injection):
        # forward pass
        h1 = self.group1(self.conv1(x))
        h1 = self.relu(h1)
        h2 = self.conv2(h1)
        
        
        if branch == 0:
            h2 += injection
        
        h2 = self.group2(h2)
        h2 += x
        h3 = self.relu(h2)
        out = self.group3(h3)

        return out


    
''' 
    assert statement we'll need    
    assert that the number of branches == len(input_channel_vector)
    assert also that num_branches == len(kernel_size_vector)
'''

" \n    assert statement we'll need    \n    assert that the number of branches == len(input_channel_vector)\n    assert also that num_branches == len(kernel_size_vector)\n"

In [627]:
class DownSample(nn.Module):
    channels: List[int]
    branches: Tuple[int]
    num_groups: int
    kernel_init = jax.nn.initializers.glorot_normal()

    def _downsample(self):
        to_res, from_res = self.branches  # sampling from resolution from_res to to_res
        num_samples = to_res - from_res
        assert num_samples > 0

        down_block = []

        for n in range(num_samples):
            inter_chan = self.in_chan if n < num_samples-1 else self.out_chan
            conv_down = nn.Conv(features=inter_chan, kernel_size=(3, 3), strides=(2,2), padding=1, use_bias=False)
                               #, kernel_init=self.kernel_init, use_bias=False)
            group_down = nn.GroupNorm(num_groups=self.num_groups)
                                      #group_size=inter_chan)
            relu_down = nn.relu
            # module_list = [conv_down, group_down]
            if n < num_samples - 1:
                # module = nn.Sequential([conv_down,
                # module = [conv_down, group_down, relu_down]
                down_block += [conv_down, group_down, relu_down]
            else:
                # module = nn.Sequential([conv_down,
                # module = [conv_down, group_down]
                down_block += [conv_down, group_down]
            #down_block.append(module)
        print("down_block", down_block)
        seq = nn.Sequential(down_block)
        print('seq', seq)
        return seq

    def setup(self):
        self.in_chan = self.channels[self.branches[0]]
        self.out_chan  = self.channels[self.branches[1]]
        #self.downsample_fn = self._downsample()
        to_res, from_res = self.branches  # sampling from resolution from_res to to_res
        num_samples = to_res - from_res 
        assert num_samples > 0

        down_block = []
        print("num_samples", num_samples)
        for n in range(num_samples):
            print("nnnnnnn", n)
            inter_chan = self.in_chan if n < num_samples-1 else self.out_chan
            
            print("inter_chan", inter_chan)
            conv_down = nn.Conv(features=inter_chan, kernel_size=(3,3), strides=(2,2), padding=((1,1),(1,1)), use_bias=False)
                               #, kernel_init=self.kernel_init, use_bias=False)
            group_down = nn.GroupNorm(num_groups=self.num_groups)
                                      #group_size=inter_chan)
            relu_down = nn.relu
            # module_list = [conv_down, group_down]
            
            #down_block.append(conv_down)
            #down_block.append(group_down)
                 
            if n < num_samples - 1:
                # module = nn.Sequential([conv_down,
                # module = [conv_down, group_down, relu_down]
                down_block += [conv_down, group_down, relu_down]
            else:
                # module = nn.Sequential([conv_down,
                # module = [conv_down, group_down]
                down_block += [conv_down, group_down]
            
            #down_block.append(module)
            
        print("down_block", down_block)
        self.downsample_fn = nn.Sequential(down_block)
        #self.layers = down_block
        #print('seq', self.layers)

    def __call__(self, z_plus):
        print("zPlus", z_plus.shape)
        out = self.downsample_fn(z_plus)
        '''
        z = z_plus
        for i, lyr in enumerate(self.layers[:-1]):
            print(i, z.shape)
            z = lyr(z)
            z = nn.relu(z)
            print(i, z.shape)
            #z = nn.sigmoid(z)  # nn.silu(z)  # jnp.tanh(z)  # nn.sigmoid(z)
        out = self.layers[-1](z)
        '''
        print("out", out.shape)
        return out


In [628]:
class UpSample(nn.Module):
    channels: List[int]
    branches: Tuple[int]
    num_groups: int
    kernel_init = jax.nn.initializers.glorot_normal()

    
    def setup(self):
        self.in_chan = self.channels[self.branches[0]]
        self.out_chan = self.channels[self.branches[1]]
        self.upsample_fn = self._upsample()
        
    ''' the following is from https://github.com/google/jax/issues/862 '''
    
    def interpolate_bilinear(self, im, rows, cols):
        # based on http://stackoverflow.com/a/12729229
        col_lo = np.floor(cols).astype(int)
        col_hi = col_lo + 1
        row_lo = np.floor(rows).astype(int)
        row_hi = row_lo + 1

        nrows, ncols = im.shape[-3:-1]
        def cclip(cols): return np.clip(cols, 0, ncols - 1)
        def rclip(rows): return np.clip(rows, 0, nrows - 1)
        Ia = im[..., rclip(row_lo), cclip(col_lo), :]
        Ib = im[..., rclip(row_hi), cclip(col_lo), :]
        Ic = im[..., rclip(row_lo), cclip(col_hi), :]
        Id = im[..., rclip(row_hi), cclip(col_hi), :]

        wa = np.expand_dims((col_hi - cols) * (row_hi - rows), -1)
        wb = np.expand_dims((col_hi - cols) * (rows - row_lo), -1)
        wc = np.expand_dims((cols - col_lo) * (row_hi - rows), -1)
        wd = np.expand_dims((cols - col_lo) * (rows - row_lo), -1)

        return wa*Ia + wb*Ib + wc*Ic + wd*Id

    def upsampling_wrap(self, resize_rate):
        def upsampling_method(img):
            nrows, ncols = img.shape[-3:-1]
            delta = 0.5/resize_rate

            rows = np.linspace(delta,nrows-delta, np.int32(resize_rate*nrows))
            cols = np.linspace(delta,ncols-delta, np.int32(resize_rate*ncols))
            ROWS, COLS = np.meshgrid(rows,cols,indexing='ij')
        
            img_resize_vec = self.interpolate_bilinear(img, ROWS.flatten(), COLS.flatten())
            img_resize =  img_resize_vec.reshape(img.shape[:-3] + 
                                                (len(rows),len(cols)) + 
                                                img.shape[-1:])
        
            return img_resize
        return upsampling_method
    ''' end copy '''


    def _upsample(self):
        to_res, from_res = self.branches  # sampling from resolution from_res to to_res
        num_samples = from_res - to_res 
        assert num_samples > 0

        return nn.Sequential([nn.Conv(features=self.out_chan, kernel_size=(1, 1), use_bias=False), #kernel_init=self.kernel_init),
                        nn.GroupNorm(num_groups=self.num_groups),
                        self.upsampling_wrap(resize_rate=2**num_samples)])

    def __call__(self, z_plus):
        print("uptype", type(self.upsample_fn))
        return self.upsample_fn(z_plus)

In [629]:
class f_theta(nn.Module):
    num_branches: int
    channels: List[int]
    num_groups: int
    features: Tuple[int] = (16, 4)
    kernel_init = jax.nn.initializers.glorot_normal()
    
    # branches: List[int] = field(default_factory=lambda:[24, 24, 24])


    #  TODO HERE
    # UnfilteredStackTrace: jax.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (64, 32640, 1) and (64, 32, 32, 24)
    def cringy_reshape(self, in_vec, shape_list):
        start = 0
        out_vec = []
        in_vec = jnp.array(in_vec)
        for size in shape_list:
            my_elems = jnp.prod(jnp.array(size[1:]))
            end = start+my_elems
            my_chunk = in_vec[:, start:end]
            start += my_elems
            my_chunk = jnp.reshape(my_chunk, size)
            out_vec.append(my_chunk)

        return out_vec

    def setup(self):

        # self.downsample = DownSample(channels=self.channels,
        #                              num_groups=self.num_groups)
        # self.upsample = UpSample(channels=self.channels,
        #                          num_groups=self.num_groups)

        self.branches = self.stack_branches()

        self.fuse_branches = self.fuse()
        self.transform = self.transform_output()

    def stack_branches(self):
        branches = []
        for i in range(self.num_branches):
          branches.append(MDEQBlock(curr_branch=i, channels=self.channels))
        return branches

    def fuse(self):#, z_plus, channel_dimensions):
        # up- and downsampling stuff
        # z_plus: output of residual block
        if self.num_branches == 1:
            return None
        
        fuse_layers = []
        for i in range(self.num_branches):
            array = []
            for j in range(self.num_branches):
                if i == j:
                    # array.append(z_plus[i])
                    array.append(None)
                else:
                    if i > j:
                        sampled = DownSample(branches=(i, j), channels=self.channels, num_groups=self.num_groups)
                        #(z_plus=z_plus, branches=(i, j),
                                                 #channel_dimension=channel_dimensions)
                    elif i < j:
                        sampled = UpSample(branches=(i, j), channels=self.channels, num_groups=self.num_groups)
                        #(z_plus=z_plus, branches=(i, j),
                                                # channel_dimension=channel_dimensions)
                    # array.append(nn.Module(sampled))
                    array.append(sampled)
            # fuse_layers.append(nn.Module(array))
            fuse_layers.append(array)

        return fuse_layers
    
    def transform_output(self):
        transforms = []
        for i in range(self.num_branches):
          transforms.append(nn.Sequential([nn.relu,
                                          nn.Conv(features=self.channels[i], kernel_size=(1, 1),
                                                  #kernel_init=self.kernel_init, 
                                                  use_bias=False),
                                          nn.GroupNorm(num_groups=self.num_groups//2)]))
                                                       #group_size=self.channels[i])]))
        
        return transforms

    def __call__(self, x, injection, shape_list):
        x = self.cringy_reshape(x, shape_list)
        #print('preshape injection', injection.shape)
        #injection = self.cringy_reshape(injection, shape_list)
        # step 1: compute residual blocks
        branch_outputs = []
        
        for i in range(self.num_branches):
            branch_outputs.append(self.branches[i](x[i], i, injection[i])) # z, branch, x

        # step 2: fuse residual blocks
        fuse_outputs = []
        for i in range(self.num_branches):
          intermediate_i = jnp.zeros(branch_outputs[i].shape) 
          for j in range(self.num_branches):
            if i == j:
              intermediate_i += branch_outputs[j]
            else:
              print("i,j", i, j)
              if self.fuse_branches[i][j] is not None:
                  print(i,j)
                  print('bshape', branch_outputs[j].shape)
                  temp = self.fuse_branches[i][j](z_plus=branch_outputs[j])#, branches=(i, j))
                  print("temp", type(temp))
                  print("inter", type(intermediate_i))
                  intermediate_i += temp
                  print("inter", type(intermediate_i))
              else:
                  raise Exception("Should not happen.")
              #print('mimimi', self.fuse_branches[i][j])
              #intermediate_i += self.fuse_branches[i][j](branch_outputs[j])
          fuse_outputs.append(self.transform[i](intermediate_i))

        return fuse_outputs


    

In [630]:
class MDEQModel(nn.Module):
    solver_fn: Callable

    num_groups: int = 8
    channels: List[int] = field(default_factory=lambda:[24, 24, 24])
    branches: List[int] = field(default_factory=lambda:[1, 1, 1])
    training: bool = True
    kernel_init = jax.nn.initializers.glorot_normal()
    bias_init = jax.nn.initializers.glorot_normal()
    features: Tuple[int] = (16, 4)


    def setup(self):
        self.num_branches = len(self.branches)

        self.conv1 = nn.Conv(features=self.channels[0], 
                             kernel_size=(3,3), strides=(1,1))
        self.bn1 = nn.BatchNorm()
        self.relu = nn.relu
        self.conv2 = nn.Conv(features=self.channels[0], 
                             kernel_size=(3,3), strides=(1,1))
        self.bn2 = nn.BatchNorm()
        self.model = f_theta(num_branches=len(self.channels), channels=self.channels, num_groups=self.num_groups)
                                         
    def __call__(self, x):
        #x = self.transform(x)
        # x = self.relu(self.bn1(self.conv1(x), use_running_average=False))
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        # x = self.relu(self.bn2(self.conv2(x), use_running_average=False))
        
        x_list = [x]
        for i in range(1, self.num_branches):
            bs, H, W, _ = x_list[-1].shape
            new_item = jnp.zeros((bs, H//2, W//2, self.channels[i]))
            x_list.append(new_item)
        z_list = [jnp.zeros_like(elem) for elem in x_list]
        shape_list = [el.shape for el in z_list]
        
        bsz = x.shape[0]
        # func = lambda z: self.model(x=z, injection=x_list, shape_list=shape_list)
        def make_vec(in_vec):
            return jnp.concatenate([elem.reshape(bsz, -1, 1) for elem in in_vec], axis=1)
        def func(z):
            out = self.model(x=z, injection=x_list, shape_list=shape_list)
            return make_vec(out) 
        # z_vec = jnp.concatenate([elem.reshape(bsz, -1, 1) for elem in z_list], axis=1)
        z_vec = make_vec(z_list) 
        result = self.solver_fn(func, z_vec, threshold=3)
        z_vec = result['result']
        output = z_vec
        if self.training:
            output = func(z_vec)
            #output = func(z_vec.requires_grad_())
        # jac_loss = jac_loss_estimate(output, z1) # comes from the follow-up paper
        
        y_list = output # TO DO -- for now without dropout!
        print('y shape shape', y_list.shape)

        return y_list

In [631]:
def transform(image, label, num_classes=10):
    image = jnp.float32(image) / 255.
    # label = jax.nn.one_hot(label, num_classes=num_classes)
    label = jnp.array(label)
    return image, label

def load_data():
    test_ds = torchvision.datasets.CIFAR10(root="data", train=False,download=True)
    train_ds = torchvision.datasets.CIFAR10(root="data", train=True,download=True)

    train_images, train_labels = transform(train_ds.data[:1000], train_ds.targets[:1000])
    test_images, test_labels = transform(test_ds.data[:200], test_ds.targets[:200])
    return train_images, train_labels, test_images, test_labels

In [632]:
def forward_fn(head, mdeq, weights, images):
    '''
    mdeq: lambda function from below (taking weights and images as arguments)
    '''
    y_batch = mdeq(weights['mdeq'], images)
    logits = head.apply(weights['head'], y_batch)

    return logits

In [633]:
def train():
    '''
    extra thing: warm-up using gradient descent in pytorch code of official repo
    --> check impact of that and maybe also cost etc (eg if only one layer etc)
    '''

    max_itr = 1000
    print_interval = 100

    train_images, train_labels, test_images, test_labels = load_data()
    train_size = train_images.shape[0]
    batch_size = 64
    assert batch_size <= train_images.shape[0]

    solver_fn = broyden_jax

    my_model = MDEQModel(solver_fn=solver_fn)
    my_deq = lambda mdeq_weights, images: my_model.apply(mdeq_weights, images)

    # def cross_entropy_loss(*, logits, labels):
    def cross_entropy_loss(logits, labels):
        ''' 
        should be same as  optax.softmax_cross_entropy(logits, labels); 
        if getting funny results maybe remove log of logits
        '''
        one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
        return -jnp.mean(jnp.sum(one_hot_labels * jnp.log(logits), axis=-1))

    png = jax.random.PRNGKey(0)
    png, _ = jax.random.split(png, 2)
    dummy_input = jnp.ones((batch_size, 32, 32, 3))
    cls_dummy_input = jnp.ones((32256,))

    print('calling model.init')
    mdeq_weights = my_model.init(rngs=png, x=dummy_input)
    png, _ = jax.random.split(png, 2)
    head = nn.Dense(10)
    cls_weights = head.init(rngs=png, x=cls_dummy_input)
    weights = {'mdeq': mdeq_weights, 'head': cls_weights}

    optimizer = optax.adamw(learning_rate=0.001, weight_decay=0.001)
    # optimizer = flax.optim.Adam(learning_rate=1e-3).create(weights)
    opt_state = optimizer.init(weights)

    loss_fn = cross_entropy_loss
    def loss(weights, x_batch, y_true):
        logits = forward_fn(head, my_deq, weights, x_batch)
        # y_batch = my_deq(weights, x_batch)
        return loss_fn(logits, y_true)
  
    def step(cls_weights, weights, opt_state, x_batch, y_true):
        print('learning how to walk')
        loss_vals, grad = jax.value_and_grad(loss, has_aux=True)(weights, x_batch, y_true)
        updates, opt_state = optimizer.update(grad, opt_state, weights)
        weights = optax.apply_updates(weights, updates)

        return weights, opt_state, loss_vals

    def generator(batch_size: int=10):
        ''' https://optax.readthedocs.io/en/latest/meta_learning.html?highlight=generator#meta-learning '''
        rng = jax.random.PRNGKey(0)

        while True:
            rng, k1 = jax.random.split(rng, num=2)
            idxs = jax.random.randint(k1, shape=(batch_size,), minval=0, maxval=train_size, dtype=jnp.int32)
            print('idxs', idxs)
            yield idxs

    def list_shuffler():
        rng = jax.random.PRNGKey(0)
        rng, k1 = jax.random.split(rng, num=2)
        indices = jnp.arange(0, train_images.shape[0])
        shuffled_indices = jax.random.shuffle(k1, indices)

        return shuffled_indices


    # g = generator(batch_size=batch_size)
    # print('g', g)

    
    for itr in range(max_itr):
        idxs = list_shuffler()
        start, end = 0, 0
           
        # batch_idxs = next(g)
        while end < len(idxs):
            print('start', start, 'end', end)
            end = min(start+batch_size, len(idxs))
            idxs_to_grab = idxs[start:end]
            print('start', start, 'end', end)
            x_batch = train_images[idxs_to_grab]
            y_true = train_labels[idxs_to_grab]
            start = end
            print('start', start, 'end', end)
  
            weights, opt_state, loss_vals = step(weights=weights,
                                                opt_state=opt_state,
                                                x_batch=x_batch,
                                                y_true=y_true)
            # loss_vals, grads = jax.value_and_grad(loss, has_aux=False)(optimizer.target, x_batch, y_true)
            # optimizer = optimizer.apply_gradient(grads)
            
            if itr % print_interval == 0:
                print("\tat step", itr, "have loss", loss_vals)

            if loss_vals < 1e-5:
                break
        

In [634]:
train()

Files already downloaded and verified
Files already downloaded and verified
calling model.init
i,j 0 1
0 1
bshape (64, 16, 16, 24)
uptype <class 'flax.linen.combinators.Sequential'>
temp <class 'jaxlib.xla_extension.DeviceArray'>
inter <class 'jaxlib.xla_extension.DeviceArray'>
inter <class 'jaxlib.xla_extension.DeviceArray'>
i,j 0 2
0 2
bshape (64, 8, 8, 24)
uptype <class 'flax.linen.combinators.Sequential'>
temp <class 'jaxlib.xla_extension.DeviceArray'>
inter <class 'jaxlib.xla_extension.DeviceArray'>
inter <class 'jaxlib.xla_extension.DeviceArray'>
i,j 1 0
1 0
bshape (64, 32, 32, 24)
num_samples 1
nnnnnnn 0
inter_chan 24
down_block [Conv(), GroupNorm(
    # attributes
    num_groups = 8
    group_size = None
    epsilon = 1e-06
    dtype = float32
    param_dtype = float32
    use_bias = True
    use_scale = True
    bias_init = zeros
    scale_init = ones
)]
zPlus (64, 32, 32, 24)
out (64, 16, 16, 24)
temp <class 'jaxlib.xla_extension.DeviceArray'>
inter <class 'jaxlib.xla_extensi

TypeError: ignored

Breakdown of code overall:


*   MDEQ modul
*   List item

