In [1]:
# install dependencies
!pip install flax optax



In [2]:
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_check_tracer_leaks", True)
# import os
# os.environ["JAX_CHECK_TRACER_LEAKS"] = "1"
jax.checking_leaks = True
jax.check_tracer_leaks = True

import jax.lax as lax
from jax import random, jit

import optax

import flax
from flax import linen as nn

from functools import partial

import torchvision

import numpy as np  # TO DO remove np's -> jnp
import contextlib

from typing import Tuple, Union, List, OrderedDict, Callable
from dataclasses import field

# jaxopt has already implicit differentiation!!
import time

from matplotlib import pyplot as plt


In [3]:
# utility function for local random seeding
@contextlib.contextmanager
def np_temp_seed(seed):
	state = np.random.get_state()
	np.random.seed(seed)
	try:
		yield
	finally:
		np.random.set_state(state)

DEQ idea & finding stationary points with root finder, maybe root finder demo on small example (but that's close to copying from last year so maybe smth different?)

In [4]:
def _safe_norm_jax(v):
    if not jnp.all(jnp.isfinite(v)):
        return jnp.inf
    return jnp.linalg.norm(v)

def scalar_search_armijo_jax(phi, phi0, derphi0, c1=1e-4, alpha0=1, amin=0):
    ite = 0
    phi_a0 = phi(alpha0)    # First do an update with step size 1
    if phi_a0 <= phi0 + c1*alpha0*derphi0:
        return alpha0, phi_a0, ite

    # Otherwise, compute the minimizer of a quadratic interpolant
    alpha1 = -(derphi0) * alpha0**2 / 2.0 / (phi_a0 - phi0 - derphi0 * alpha0)
    phi_a1 = phi(alpha1)

    # Otherwise loop with cubic interpolation until we find an alpha which
    # satisfies the first Wolfe condition (since we are backtracking, we will
    # assume that the value of alpha is not too small and satisfies the second
    # condition.
    while alpha1 > amin:       # we are assuming alpha>0 is a descent direction
        factor = alpha0**2 * alpha1**2 * (alpha1-alpha0)
        a = alpha0**2 * (phi_a1 - phi0 - derphi0*alpha1) - \
            alpha1**2 * (phi_a0 - phi0 - derphi0*alpha0)
        a = a / factor
        b = -alpha0**3 * (phi_a1 - phi0 - derphi0*alpha1) + \
            alpha1**3 * (phi_a0 - phi0 - derphi0*alpha0)
        b = b / factor

        alpha2 = (-b + jnp.sqrt(jnp.abs(b**2 - 3 * a * derphi0))) / (3.0*a)
        phi_a2 = phi(alpha2)
        ite += 1

        if (phi_a2 <= phi0 + c1*alpha2*derphi0):
            return alpha2, phi_a2, ite

        if (alpha1 - alpha2) > alpha1 / 2.0 or (1 - alpha2/alpha1) < 0.96:
            alpha2 = alpha1 / 2.0

        alpha0 = alpha1
        alpha1 = alpha2
        phi_a0 = phi_a1
        phi_a1 = phi_a2

    # Failed to find a suitable step length
    return None, phi_a1, ite


def line_search_jax(update, x0, g0, g, nstep=0, on=True):
    """
    `update` is the propsoed direction of update.

    Code adapted from scipy.
    """
    tmp_s = [0]
    tmp_g0 = [g0]
    tmp_phi = [jnp.linalg.norm(g0)**2]
    s_norm = jnp.linalg.norm(x0) / jnp.linalg.norm(update)

    def phi(s, store=True):
        if s == tmp_s[0]:
            return tmp_phi[0]    # If the step size is so small... just return something
        x_est = x0 + s * update
        g0_new = g(x_est)
        phi_new = _safe_norm_jax(g0_new)**2
        if store:
            tmp_s[0] = s
            tmp_g0[0] = g0_new
            tmp_phi[0] = phi_new
        return phi_new
    
    if on:
        s, phi1, ite = scalar_search_armijo_jax(phi, tmp_phi[0], -tmp_phi[0], amin=1e-2)
    if (not on) or s is None:
        s = 1.0
        ite = 0

    x_est = x0 + s * update
    if s == tmp_s[0]:
        g0_new = tmp_g0[0]
    else:
        g0_new = g(x_est)
    return x_est, g0_new, x_est - x0, g0_new - g0, ite



In [5]:
def rmatvec_jax(part_Us, part_VTs, x):
    # Compute x^T(-I + UV^T)
    # x: (N, 2d, L')
    # part_Us: (N, 2d, L', threshold)
    # part_VTs: (N, threshold, 2d, L')
    if jnp.size(part_Us) == 0:
        return -x
    xTU = jnp.einsum('bij, bijd -> bd', x, part_Us)   # (N, threshold)
    return -x + jnp.einsum('bd, bdij -> bij', xTU, part_VTs)    # (N, 2d, L'), but should really be (N, 1, (2d*L'))

def matvec_jax(part_Us, part_VTs, x):
    # Compute (-I + UV^T)x
    # x: (N, 2d, L')
    # part_Us: (N, 2d, L', threshold)
    # part_VTs: (N, threshold, 2d, L')
    if jnp.size(part_Us) == 0:
        return -x
    VTx = jnp.einsum('bdij, bij -> bd', part_VTs, x)  # (N, threshold)
    return -x + jnp.einsum('bijd, bd -> bij', part_Us, VTx)     # (N, 2d, L'), but should really be (N, (2d*L'), 1)


In [6]:
def broyden_jax(g, x0, threshold, eps=1e-3, stop_mode="rel", result_dict=False, ls=False, name="unknown"):
    bsz, total_hsize, seq_len = x0.shape
    # g = lambda y: f(y) - y
    dev = x0.device()
    alternative_mode = 'rel' if stop_mode == 'abs' else 'abs'
    
    x_est = x0           # (bsz, 2d, L')
    gx = g(x_est)        # (bsz, 2d, L')
    nstep = 0
    tnstep = 0
    
    # For fast calculation of inv_jacobian (approximately)
    Us = jax.device_put(jnp.zeros((bsz, total_hsize, seq_len, threshold)),dev)     # One can also use an L-BFGS scheme to further reduce memory
    VTs = jax.device_put(jnp.zeros((bsz, threshold, total_hsize, seq_len)),dev)
    update = -matvec_jax(Us[:,:,:,:nstep], VTs[:,:nstep], gx)      # Formally should be -torch.matmul(inv_jacobian (-I), gx)
    prot_break = False
    
    # To be used in protective breaks
    protect_thres = (1e6 if stop_mode == "abs" else 1e3) * seq_len
    new_objective = 1e8

    trace_dict = {'abs': [],
                  'rel': []}
    lowest_dict = {'abs': 1e8,
                   'rel': 1e8}
    lowest_step_dict = {'abs': 0,
                        'rel': 0}
    nstep, lowest_xest, lowest_gx = 0, x_est, gx

    while nstep < threshold:
        x_est, gx, delta_x, delta_gx, ite = line_search_jax(update, x_est, gx, g, nstep=nstep, on=ls)
        nstep += 1
        tnstep += (ite+1)
        abs_diff = jnp.linalg.norm(gx)
        rel_diff = abs_diff / (jnp.linalg.norm(gx + x_est) + 1e-9)
        diff_dict = {'abs': abs_diff,
                     'rel': rel_diff}
        trace_dict['abs'].append(abs_diff)
        trace_dict['rel'].append(rel_diff)
        for mode in ['rel', 'abs']:
            if diff_dict[mode] < lowest_dict[mode]:
                if mode == stop_mode: 
                    # lowest_xest, lowest_gx = lax.stop_gradient(x_est.copy()), lax.stop_gradient(gx.copy())
                    # lowest_xest, lowest_gx = lax.stop_gradient(x_est), lax.stop_gradient(gx)
                    # lowest_xest, lowest_gx = lax.stop_gradient(jnp.copy(x_est)), lax.stop_gradient(jnp.copy(gx))
                    lowest_xest, lowest_gx = jnp.copy(x_est), jnp.copy(gx)
                lowest_dict[mode] = diff_dict[mode]
                lowest_step_dict[mode] = nstep

        new_objective = diff_dict[stop_mode]
        if new_objective < eps: break
        if new_objective < 3*eps and nstep > 30 and np.max(trace_dict[stop_mode][-30:]) / np.min(trace_dict[stop_mode][-30:]) < 1.3:
            # if there's hardly been any progress in the last 30 steps
            break
        if new_objective > trace_dict[stop_mode][0] * protect_thres:
            prot_break = True
            break

        part_Us, part_VTs = Us[:,:,:,:nstep-1], VTs[:,:nstep-1]
        vT = rmatvec_jax(part_Us, part_VTs, delta_x)
        u = (delta_x - matvec_jax(part_Us, part_VTs, delta_gx)) / jnp.einsum('bij, bij -> b', vT, delta_gx)[:,None,None]
        vT = jnp.nan_to_num(vT,nan=0.)
        u = jnp.nan_to_num(u,nan=0.)
        VTs = VTs.at[:,nstep-1].set(vT)
        Us = Us.at[:,:,:,nstep-1].set(u)
        update = -matvec_jax(Us[:,:,:,:nstep], VTs[:,:nstep], gx)

    # Fill everything up to the threshold length
    for _ in range(threshold+1-len(trace_dict[stop_mode])):
        trace_dict[stop_mode].append(lowest_dict[stop_mode])
        trace_dict[alternative_mode].append(lowest_dict[alternative_mode])

    if result_dict:
        return {"result": lowest_xest,
                "lowest": lowest_dict[stop_mode],
                "nstep": lowest_step_dict[stop_mode],
                "prot_break": prot_break,
                "abs_trace": trace_dict['abs'],
                "rel_trace": trace_dict['rel'],
                "eps": eps,
                "threshold": threshold}
    else:
        return lowest_xest


def newton_jax(g, x0, threshold, eps=1e-3, stop_mode="rel", result_dict=False, name="unknown"):

    # g = lambda y: f(y) - y
    jac_g = jax.jacfwd(g)
    x = x0
    gx = g(x)
    gx_norm = jnp.linalg.norm(gx)
    nstep = 0
    # print(gx_norm)

    while nstep < threshold:
      # solve system
      delta_x = jnp.linalg.solve(jac_g(x),-g(x))
      x = x + delta_x
      gx = g(x)
      gx_norm = jnp.linalg.norm(gx)
      nstep += 1
      # print(gx_norm)

    if result_dict:
        return {'result': x, 'gradient': gx, 'gx_norm': gx_norm}
    else:
        return x

Talking about MDEQ model...

In [7]:
class MDEQBlock(nn.Module):
    curr_branch: int
    channels: List[int]
    kernel_size: Tuple[int] = (3, 3)  # can also be (5, 5), modify later
    num_groups: int = 2
    kernel_init = jax.nn.initializers.glorot_normal()
    bias_init = jax.nn.initializers.glorot_normal()

    
    def setup(self):  
        self.input_dim = self.channels[self.curr_branch]
        self.hidden_dim = 2*self.input_dim

        # init-substitute for flax
        self.conv1 = nn.Conv(features=self.hidden_dim, kernel_size=self.kernel_size,
                             strides=1)#, kernel_init=self.kernel_init, bias_init=self.bias_init)
        self.group1 = nn.GroupNorm(num_groups=self.num_groups)
        self.relu = nn.relu
        self.conv2 = nn.Conv(features=self.input_dim, kernel_size=self.kernel_size,
                             strides=1)#, kernel_init=self.kernel_init, bias_init=self.bias_init)
        self.group2 = nn.GroupNorm(num_groups=self.num_groups)
        self.group3 = nn.GroupNorm(num_groups=self.num_groups)


    def __call__(self, x, branch, injection):
        # forward pass
        h1 = self.group1(self.conv1(x))
        h1 = self.relu(h1)
        h2 = self.conv2(h1)
        
        
        if branch == 0:
            h2 += injection
        
        h2 = self.group2(h2)
        h2 += x
        h3 = self.relu(h2)
        out = self.group3(h3)

        return out


    
''' 
    assert statement we'll need    
    assert that the number of branches == len(input_channel_vector)
    assert also that num_branches == len(kernel_size_vector)
'''

" \n    assert statement we'll need    \n    assert that the number of branches == len(input_channel_vector)\n    assert also that num_branches == len(kernel_size_vector)\n"

In [8]:
class DownSample(nn.Module):
    channels: List[int]
    branches: Tuple[int]
    num_groups: int
    kernel_init = jax.nn.initializers.glorot_normal()

    def _downsample(self):
        to_res, from_res = self.branches  # sampling from resolution from_res to to_res
        num_samples = to_res - from_res
        assert num_samples > 0

        down_block = []

        for n in range(num_samples):
            inter_chan = self.in_chan if n < num_samples-1 else self.out_chan
            conv_down = nn.Conv(features=inter_chan, kernel_size=(3, 3), strides=(2,2), padding=1, use_bias=False)
                               #, kernel_init=self.kernel_init, use_bias=False)
            group_down = nn.GroupNorm(num_groups=self.num_groups)
                                      #group_size=inter_chan)
            relu_down = nn.relu
            # module_list = [conv_down, group_down]
            if n < num_samples - 1:
                # module = nn.Sequential([conv_down,
                # module = [conv_down, group_down, relu_down]
                down_block += [conv_down, group_down, relu_down]
            else:
                # module = nn.Sequential([conv_down,
                # module = [conv_down, group_down]
                down_block += [conv_down, group_down]
            #down_block.append(module)
        seq = nn.Sequential(down_block)
        return seq

    def setup(self):
        self.in_chan = self.channels[self.branches[0]]
        self.out_chan  = self.channels[self.branches[1]]
        #self.downsample_fn = self._downsample()
        to_res, from_res = self.branches  # sampling from resolution from_res to to_res
        num_samples = to_res - from_res 
        assert num_samples > 0

        down_block = []
        for n in range(num_samples):
            inter_chan = self.in_chan if n < num_samples-1 else self.out_chan
            
            conv_down = nn.Conv(features=inter_chan, kernel_size=(3,3), strides=(2,2), padding=((1,1),(1,1)), use_bias=False)
                               #, kernel_init=self.kernel_init, use_bias=False)
            group_down = nn.GroupNorm(num_groups=self.num_groups)
                                      #group_size=inter_chan)
            relu_down = nn.relu
            # module_list = [conv_down, group_down]
            
            #down_block.append(conv_down)
            #down_block.append(group_down)
                 
            if n < num_samples - 1:
                # module = nn.Sequential([conv_down,
                # module = [conv_down, group_down, relu_down]
                down_block += [conv_down, group_down, relu_down]
            else:
                # module = nn.Sequential([conv_down,
                # module = [conv_down, group_down]
                down_block += [conv_down, group_down]
            
            #down_block.append(module)
            
        self.downsample_fn = nn.Sequential(down_block)
        #self.layers = down_block
        #print('seq', self.layers)

    def __call__(self, z_plus):
        out = self.downsample_fn(z_plus)
        '''
        z = z_plus
        for i, lyr in enumerate(self.layers[:-1]):
            print(i, z.shape)
            z = lyr(z)
            z = nn.relu(z)
            print(i, z.shape)
            #z = nn.sigmoid(z)  # nn.silu(z)  # jnp.tanh(z)  # nn.sigmoid(z)
        out = self.layers[-1](z)
        '''
        return out


In [9]:
class UpSample(nn.Module):
    channels: List[int]
    branches: Tuple[int]
    num_groups: int
    kernel_init = jax.nn.initializers.glorot_normal()

    
    def setup(self):
        self.in_chan = self.channels[self.branches[0]]
        self.out_chan = self.channels[self.branches[1]]
        self.upsample_fn = self._upsample()
        
    ''' the following is from https://github.com/google/jax/issues/862 '''
    
    def interpolate_bilinear(self, im, rows, cols):
        # based on http://stackoverflow.com/a/12729229
        col_lo = np.floor(cols).astype(int)
        col_hi = col_lo + 1
        row_lo = np.floor(rows).astype(int)
        row_hi = row_lo + 1

        nrows, ncols = im.shape[-3:-1]
        def cclip(cols): return np.clip(cols, 0, ncols - 1)
        def rclip(rows): return np.clip(rows, 0, nrows - 1)
        Ia = im[..., rclip(row_lo), cclip(col_lo), :]
        Ib = im[..., rclip(row_hi), cclip(col_lo), :]
        Ic = im[..., rclip(row_lo), cclip(col_hi), :]
        Id = im[..., rclip(row_hi), cclip(col_hi), :]

        wa = np.expand_dims((col_hi - cols) * (row_hi - rows), -1)
        wb = np.expand_dims((col_hi - cols) * (rows - row_lo), -1)
        wc = np.expand_dims((cols - col_lo) * (row_hi - rows), -1)
        wd = np.expand_dims((cols - col_lo) * (rows - row_lo), -1)

        return wa*Ia + wb*Ib + wc*Ic + wd*Id

    def upsampling_wrap(self, resize_rate):
        def upsampling_method(img):
            nrows, ncols = img.shape[-3:-1]
            delta = 0.5/resize_rate

            rows = np.linspace(delta,nrows-delta, np.int32(resize_rate*nrows))
            cols = np.linspace(delta,ncols-delta, np.int32(resize_rate*ncols))
            ROWS, COLS = np.meshgrid(rows,cols,indexing='ij')
        
            img_resize_vec = self.interpolate_bilinear(img, ROWS.flatten(), COLS.flatten())
            img_resize =  img_resize_vec.reshape(img.shape[:-3] + 
                                                (len(rows),len(cols)) + 
                                                img.shape[-1:])
        
            return img_resize
        return upsampling_method
    ''' end copy '''


    def _upsample(self):
        to_res, from_res = self.branches  # sampling from resolution from_res to to_res
        num_samples = from_res - to_res 
        assert num_samples > 0

        return nn.Sequential([nn.Conv(features=self.out_chan, kernel_size=(1, 1), use_bias=False), #kernel_init=self.kernel_init),
                        nn.GroupNorm(num_groups=self.num_groups),
                        self.upsampling_wrap(resize_rate=2**num_samples)])

    def __call__(self, z_plus):
        return self.upsample_fn(z_plus)

In [10]:
def cringy_reshape(in_vec, shape_list):
    start = 0
    out_vec = []
    in_vec = jnp.array(in_vec)
    for size in shape_list:
        my_elems = jnp.prod(jnp.array(size[1:]))
        end = start+my_elems
        my_chunk = in_vec[:, start:end]
        start += my_elems
        my_chunk = jnp.reshape(my_chunk, size)
        out_vec.append(my_chunk)

    return out_vec

def make_batch_vec(in_vec, bsz):
    return jnp.concatenate([elem.reshape(bsz, -1, 1) for elem in in_vec], axis=1)
        

In [11]:
class f_theta(nn.Module):
    num_branches: int
    channels: List[int]
    num_groups: int
    #features: Tuple[int] = (16, 4)
    kernel_init = jax.nn.initializers.glorot_normal()
    
    # branches: List[int] = field(default_factory=lambda:[24, 24, 24])


    #  TODO HERE
    # UnfilteredStackTrace: jax.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (64, 32640, 1) and (64, 32, 32, 24)
   

    def setup(self):

        # self.downsample = DownSample(channels=self.channels,
        #                              num_groups=self.num_groups)
        # self.upsample = UpSample(channels=self.channels,
        #                          num_groups=self.num_groups)

        self.branches = self.stack_branches()

        self.fuse_branches = self.fuse()
        self.transform = self.transform_output()

    def stack_branches(self):
        branches = []
        for i in range(self.num_branches):
          branches.append(MDEQBlock(curr_branch=i, channels=self.channels))
        return branches

    def fuse(self):#, z_plus, channel_dimensions):
        # up- and downsampling stuff
        # z_plus: output of residual block
        if self.num_branches == 1:
            return None
        
        fuse_layers = []
        for i in range(self.num_branches):
            array = []
            for j in range(self.num_branches):
                if i == j:
                    # array.append(z_plus[i])
                    array.append(None)
                else:
                    if i > j:
                        sampled = DownSample(branches=(i, j), channels=self.channels, num_groups=self.num_groups)
                        #(z_plus=z_plus, branches=(i, j),
                                                 #channel_dimension=channel_dimensions)
                    elif i < j:
                        sampled = UpSample(branches=(i, j), channels=self.channels, num_groups=self.num_groups)
                        #(z_plus=z_plus, branches=(i, j),
                                                # channel_dimension=channel_dimensions)
                    # array.append(nn.Module(sampled))
                    array.append(sampled)
            # fuse_layers.append(nn.Module(array))
            fuse_layers.append(array)

        return fuse_layers
    
    def transform_output(self):
        transforms = []
        for i in range(self.num_branches):
          transforms.append(nn.Sequential([nn.relu,
                                          nn.Conv(features=self.channels[i], kernel_size=(1, 1),
                                                  #kernel_init=self.kernel_init, 
                                                  use_bias=False),
                                          nn.GroupNorm(num_groups=self.num_groups//2)]))
                                                       #group_size=self.channels[i])]))
        
        return transforms

    def __call__(self, x, injection, shape_list):
        if not isinstance(x, list):
            x = cringy_reshape(x, shape_list)
        #print('preshape injection', injection.shape)
        #injection = self.cringy_reshape(injection, shape_list)
        # step 1: compute residual blocks
        branch_outputs = []
        
        for i in range(self.num_branches):
            branch_outputs.append(self.branches[i](x[i], i, injection[i])) # z, branch, x

        # step 2: fuse residual blocks
        fuse_outputs = []
        for i in range(self.num_branches):
            intermediate_i = jnp.zeros(branch_outputs[i].shape) 
            for j in range(self.num_branches):
                if i == j:
                  intermediate_i += branch_outputs[j]
                else:
                    if self.fuse_branches[i][j] is not None:
                        temp = self.fuse_branches[i][j](z_plus=branch_outputs[j])#, branches=(i, j))
                        intermediate_i += temp
                    else:
                        raise Exception("Should not happen.")
                #print('mimimi', self.fuse_branches[i][j])
                #intermediate_i += self.fuse_branches[i][j](branch_outputs[j])
            fuse_outputs.append(self.transform[i](intermediate_i))

        return fuse_outputs


    

In [12]:
class InitLayer(nn.Module):
    # solver_fn: Callable

    num_groups: int = 8
    channels: List[int] = field(default_factory=lambda:[24, 24])
    branches: List[int] = field(default_factory=lambda:[1, 1])
    #channels: List[int] = field(default_factory=lambda:[24, 24, 24])
    #branches: List[int] = field(default_factory=lambda:[1, 1, 1])
    training: bool = True
    kernel_init = jax.nn.initializers.glorot_normal()
    bias_init = jax.nn.initializers.glorot_normal()
    features: Tuple[int] = (16, 4)

    def setup(self):
        self.num_branches = len(self.branches)

        self.conv1 = nn.Conv(features=self.channels[0], 
                             kernel_size=(3, 3), strides=(1, 1))
        self.bn1 = nn.BatchNorm()
        self.relu = nn.relu
        self.conv2 = nn.Conv(features=self.channels[0], 
                             kernel_size=(3, 3), strides=(1, 1))
        self.bn2 = nn.BatchNorm()
        # self.model = f_theta(num_branches=len(self.channels), channels=self.channels, num_groups=self.num_groups)

    def __call__(self, x):
        print('in model', x.shape)
        x = self.relu(self.bn1(self.conv1(x), use_running_average=True))
        x = self.relu(self.bn2(self.conv2(x), use_running_average=True))
        
        x_list = [x]
        for i in range(1, self.num_branches):
            bs, H, W, _ = x_list[-1].shape
            new_item = jnp.zeros((bs, H//2, W//2, self.channels[i]))
            x_list.append(new_item)
        z_list = [jnp.zeros_like(elem) for elem in x_list]
        shape_list = [el.shape for el in z_list]
        # bsz = x.shape[0]
        # func = lambda z: self.model(x=z, injection=x_list, shape_list=shape_list)
        # def make_batch_vec(in_vec, bsz):
        #     return jnp.concatenate([elem.reshape(bsz, -1, 1) for elem in in_vec], axis=1)
            
        # def f_fn(z):
        #     out = self.model(x=make_vec(z), injection=x_list, shape_list=shape_list)
        #     print('outshapes', [o.shape for o in out])
        #     # return make_vec(out) - z
        #     return make_vec(out)

        return x_list, z_list, shape_list

        
        # z_vec = jnp.concatenate([elem.reshape(bsz, -1, 1) for elem in z_list], axis=1)
        #z_vec = make_vec(z_list) 

        ''' do forward backward thing here '''
        # z_vec = self._solve_mdeq(f_fn, x0, threshold=3)
        # forward, backward = self.solver_fn[0], self.solver_fn[1]

        # result = self.solver_fn(func, z_vec, threshold=3) # TODO have threshold as param
         
        # z_vec = result['result']
        # output = z_vec
        '''
        if self.training:
            output = func(z_vec)
        '''
            #output = func(z_vec.requires_grad_())
        # jac_loss = jac_loss_estimate(output, z1) # comes from the follow-up paper
        
        # y_list = cringy_reshape(output, shape_list) # TO DO -- for now without dropout!
        
        # return z_vec, f_fn, shape_list 

In [13]:
class CLSBlock(nn.Module):
    input_dim: int
    output_dim: int
    downsample: bool
    expansion: int=4
    
    def setup(self):  


        # init-substitute for flax
        self.conv1 = nn.Conv(features=self.output_dim, kernel_size=(1,1),
                             strides=(1,1))#, kernel_init=self.kernel_init, bias_init=self.bias_init)
        self.bn1 = nn.BatchNorm()
        self.relu = nn.relu
        self.conv2 = nn.Conv(features=self.output_dim, kernel_size=(3,3), strides=(1,1))#, kernel_init=self.kernel_init, bias_init=self.bias_init)
        self.bn2 = nn.BatchNorm()
        self.conv3 = nn.Conv(features=self.output_dim*self.expansion, kernel_size=(1,1), strides=(1,1))#, kernel_init=self.kernel_init, bias_init=self.bias_init)
        self.bn3 = nn.BatchNorm()

        if self.downsample:
            self.ds_conv = nn.Conv(self.output_dim*self.expansion, kernel_size=(1,1), strides=(1,1), use_bias=False)
            self.ds_bn = nn.BatchNorm()


    def __call__(self, x, injection=None):
        # forward pass
        if injection is None:
          injection = 0
        h1 = self.bn1(self.conv1(x), use_running_average=True)
        h1 = self.relu(h1)
        h2 = self.bn2(self.conv2(h1), use_running_average=True)
        h2 = self.relu(h2)
        h3 = self.bn3(self.conv3(h2), use_running_average=True)
        if self.downsample:
          x = self.ds_bn(self.ds_conv(x), use_running_average=True)
        h3 += x
        return nn.relu(h3)

...

In [14]:
class Classifier(nn.Module):
    #channels: List[int] = field(default_factory=lambda:[24, 24, 24])
    channels: List[int] = field(default_factory=lambda:[24, 24])
    #output_channels: List[int] = field(default_factory=lambda:[4, 8, 16])
    output_channels: List[int] = field(default_factory=lambda:[8, 16])
    expansion: int = 4
    final_chansize: int = 200
    num_classes: int = 10

    def _make_layer(self, inplanes, planes):
          downsample = False
          if inplanes != planes * self.expansion:
              downsample = True
          return CLSBlock(inplanes, planes, downsample)

    def setup(self):
        self.num_branches = len(self.channels)

        incre_modules = []
        for i, channels  in enumerate(self.channels):
            incre_mod = self._make_layer(self.channels[i], self.output_channels[i])
            incre_modules.append(incre_mod)
        self.incre_modules = incre_modules
        downsamp_modules = []
        for i in range(len(self.channels)-1):
            in_channels = self.output_channels[i] * self.expansion
            out_channels = self.output_channels[i+1] * self.expansion
            downsamp_module = nn.Sequential([nn.Conv(out_channels, kernel_size=(3,3), strides=(2,2), use_bias=True),
                                            #nn.BatchNorm(),
                                            nn.relu])
            downsamp_modules.append(downsamp_module)
        self.downsamp_modules = downsamp_modules

        self.final_layer = nn.Sequential([nn.Conv(self.final_chansize, kernel_size=(1,1)),
                                         #nn.BatchNorm(),
                                         nn.relu])
        self.classifier = nn.Dense(self.num_classes)
                                         
    def __call__(self, y_list):
        y = self.incre_modules[0](y_list[0])
        for i in range(len(self.downsamp_modules)):
            y = self.incre_modules[i+1](y_list[i+1]) + self.downsamp_modules[i](y)
        y = self.final_layer(y)
        y = nn.avg_pool(y, window_shape=y.shape[1:3])
        y = jnp.reshape(y, (y.shape[0], -1))
        y = self.classifier(y)
        return y

loading MNIST data...

In [15]:
def transform(image, label, num_classes=10):
    image = jnp.float32(image) / 255.
    label = jnp.array(label)
    return image, label

def load_data():
    test_ds = torchvision.datasets.CIFAR10(root="data", train=False,download=True)
    train_ds = torchvision.datasets.CIFAR10(root="data", train=True,download=True)

    train_images, train_labels = transform(train_ds.data[:5000], train_ds.targets[:5000])
    test_images, test_labels = transform(test_ds.data[:200], test_ds.targets[:200])
    return train_images, train_labels, test_images, test_labels

Assembling stuff: 
*   implicit differentiation stuff 
*   ... 





In [16]:
'''
Rootfinder forward
'''

# def broyden_jax(f, x0, threshold, eps=1e-3, stop_mode="rel", ls=False, name="unknown"):

@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3,)) # nondiff are all except for weights and z/x
def rootfind_fwd(solver_fn: Callable,
                 f_fn: Callable,
                 threshold: int,
                 eps: float,
                 weights: dict,
                 x0: jnp.ndarray):
    print('\t\tat pos 00')
    print('x0 type inside rootfind fwd', type(x0))
    f_fn = partial(f_fn, weights)
    print('\t\tat pos 11')
    #print('kn forward', x0.shape)
    print('\t\tat pos 22')
    return jax.lax.stop_gradient(solver_fn(f_fn, x0, threshold, eps=1e-3))

# Its forward call (basically just calling it)
def _fwd_rootfind_fwd(solver_fn: Callable,
                      f_fn: Callable,
                      threshold: int,
                      eps: float,
                      weights: dict,
                      z: jnp.ndarray):
    z = rootfind_fwd(solver_fn, f_fn, threshold, eps, weights, z)
    return z, (weights, z)

# Its backward call (its inputs)
def _fwd_rootfind_bwd(solver_fn: Callable,
                      f_fn: Callable,
                      threshold: int,
                      eps: float,
                      res,  
                      grad):
    # weights, z = res
    
    return None, grad 
    # return None, None, None, None, None, grad 

rootfind_fwd.defvjp(_fwd_rootfind_fwd, _fwd_rootfind_bwd)

'''
Rootfinder backward
'''
@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3,))
def rootfind_bwd(solver_fn: Callable,
                 f_fn: Callable,
                 threshold: int,
                 eps: float,
                 weights: dict,
                 x0: jnp.ndarray):
    print('rootfind_bwd seems ok')
    f_fn = partial(f_fn, weights)
    return jax.lax.stop_gradient(solver_fn(f_fn, x0, threshold, eps))

def _fwd_rootfind_bwd(solver_fn: Callable,
                      f_fn: Callable,
                      threshold: int,
                      eps: float,
                      weights: dict,
                      z: jnp.ndarray):
    print('_fwd_rootfind_bwd seems ok')
    return z, (weights, z)

def _bwd_rootfind_bwd(solver_fn: Callable,
                      f_fn: Callable,
                      threshold: int,
                      eps: float,
                      res,
                      grad):
    print('ok11')
    weights, z = res
    _, vjp_fun = jax.vjp(f_fn, weights, z)
    print('ok13')

    def x_fn(x): # gets transpose Jac w.r.t. weights and z using vjp_fun
        Jw_T, Jz_T = vjp_fun(x)
        return Jz_T + grad

    print('ok13')
    g_0 = jnp.zeros_like(grad)
    g = solver_fn(x_fn, g_0, threshold, eps)

    return None, g
    # return None, None, None, None, None, g 

rootfind_bwd.defvjp(_fwd_rootfind_bwd, _bwd_rootfind_bwd)



In [17]:
def solve_mdeq(solver_fn, f_fn, shape_list, weights, x0, threshold, eps):
    g_fn = lambda weights, z: f_fn(weights, z) - z
    
    print('\tat pos 0')
    print('x0 type atm', type(x0))
    print('bsz2e', shape_list[0][0])
    x0 = make_batch_vec(x0, shape_list[0][0])
    print('x0 type as vec', type(x0))
    # z_vec_star = rootfind_fwd(solver_fn, f_fn, threshold, eps, weights, cringy_reshape(x0, shape_list))
    z_vec_star = rootfind_fwd(solver_fn, f_fn, threshold, eps, weights, x0)
    print('\tat pos 1')
    print('zshape and shapelist', z_vec_star.shape, shape_list)
    # z_star = cringy_reshape(z_vec_star, shape_list)
    print('\tat pos 2')
    z_vec_star = f_fn(weights, z_vec_star)
    print('\tat pos 3')
    z_vec_star = rootfind_bwd(solver_fn, f_fn, threshold, eps, weights, x0)
    print('\tat pos 4')
    
    return z_vec_star # TODO cringy reshape


In [18]:
def forward_fn(solver_fn: Callable, head: nn.Module, init: nn.Module, f_th: nn.Module, weights: dict, images: jnp.ndarray):
    '''
    mdeq: lambda function from below (taking weights and images as arguments)
    '''

    # model_fn return x_list, z_list, shape_list
    init_fn = lambda init_weights, images: init.apply(init_weights, images)
    # def f_fn(mdeq_weights, images):
    #     z, _ = mdeq_fn(mdeq_weights, images)
    #     return z


    threshold = 3
    eps = 1e-3

    x_list, z_list, shape_list = init_fn(weights['init'], images) # take care of shape_list

    #  f_theta_fn takes x, injection, shape_list and returns list of fused layers (list of lists)
    def f_theta_fn(weights, z):
        out = f_th.apply(weights, z, x_list, shape_list)
        return  make_batch_vec(out, z_list[0].shape[0])

    # print("shape out of fn: ", z.shape)
    z_star = solve_mdeq(solver_fn=solver_fn, f_fn=f_theta_fn, shape_list=shape_list, weights=weights['f_theta'], x0=z_list, threshold=threshold, eps=eps)
    print('solved mdeq model')
    y_batch = cringy_reshape(z_star, shape_list)
    # y_batch = mdeq(weights['mdeq'], images)
    logits = head.apply(weights['head'], y_batch)
    return logits

In [19]:
def train():
    '''
    extra thing: warm-up using gradient descent in pytorch code of official repo
    --> check impact of that and maybe also cost etc (eg if only one layer etc)
    '''

    max_itr = 7 
    print_interval = 5

    train_images, train_labels, test_images, test_labels = load_data()
    train_size = train_images.shape[0]
    batch_size = 128
    assert batch_size <= train_images.shape[0]

    solver_fn = broyden_jax

    my_init = InitLayer()
    my_f_theta = f_theta(num_branches=len(my_init.channels), channels=my_init.channels, num_groups=my_init.num_groups)


    # def cross_entropy_loss(*, logits, labels):
    def cross_entropy_loss(logits, labels):
        ''' 
        should be same as  optax.softmax_cross_entropy(logits, labels); 
        if getting funny results maybe remove log of logits
        '''
        one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
        logits = jax.nn.log_softmax(logits)
        acc = (jnp.argmax(logits, axis=-1) == labels).mean()
        output = -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))
        return output, acc

    png = jax.random.PRNGKey(0)
    # png, _ = jax.random.split(png, 2)
    init_dummy_input = jnp.ones((batch_size, 32, 32, 3))
    f_theta_dummy_input = jnp.ones((batch_size, 32, 32, 24))
    cls_dummy_input = [jnp.ones((batch_size, 32, 32, 24)),
                       jnp.ones((batch_size, 16, 16, 24))]
                       #jnp.ones((64, 8, 8, 24)),]
    (dummy_x, dummy_inj, dummy_shape), init_weights = my_init.init_with_output(rngs=png, x=init_dummy_input)
    print('shape_list', dummy_shape)
    dummy_x = make_batch_vec(dummy_x, batch_size)
    print('test')
    # png, _ = jax.random.split(png, 2)
    f_theta_weights = my_f_theta.init(rngs=png, x=dummy_x, injection=dummy_inj, shape_list=dummy_shape)
    # png, _ = jax.random.split(png, 2)
    head = Classifier()
    cls_weights = head.init(png, cls_dummy_input)
    weights = {'init': init_weights, 'f_theta': f_theta_weights ,'head': cls_weights}

    optimizer = optax.adam(learning_rate=0.001)
    opt_state = optimizer.init(weights)

    loss_fn = cross_entropy_loss

    def loss(weights, x_batch, y_true):
        logits = forward_fn(solver_fn, head, my_init, my_f_theta, weights, x_batch)
        return loss_fn(logits, y_true)

    def step(weights, opt_state, x_batch, y_true):
        (loss_vals, acc), grad = jax.value_and_grad(loss, has_aux=True)(weights, x_batch, y_true)
        updates, opt_state = optimizer.update(grad, opt_state, weights)
        weights = optax.apply_updates(weights, updates)

        return weights, opt_state, loss_vals, acc

    def generator(batch_size: int=10):
        ''' https://optax.readthedocs.io/en/latest/meta_learning.html?highlight=generator#meta-learning '''
        rng = jax.random.PRNGKey(0)

        while True:
            rng, k1 = jax.random.split(rng, num=2)
            idxs = jax.random.randint(k1, shape=(batch_size,), minval=0, maxval=train_size, dtype=jnp.int32)
            yield idxs

    def list_shuffler():
        rng = jax.random.PRNGKey(0)
        rng, k1 = jax.random.split(rng, num=2)
        indices = jnp.arange(0, train_images.shape[0])
        shuffled_indices = jax.random.shuffle(k1, indices)

        return shuffled_indices

    max_epoch = 7 
    print_interval = 1

    train_log, val_log = [],[]
    for epoch in range(max_epoch):
        idxs = list_shuffler()
        start, end = 0, 0
        loss_vals = []
        acc_vals = []
        while end < len(idxs):
            end = min(start+batch_size, len(idxs))
            idxs_to_grab = idxs[start:end]
            x_batch = train_images[idxs_to_grab]
            y_true = train_labels[idxs_to_grab]
            start = end
  
            weights, opt_state, batch_loss, batch_acc = step(weights=weights,
                                                 opt_state=opt_state,
                                                 x_batch=x_batch,
                                                 y_true=y_true)
            loss_vals.append(batch_loss)
            acc_vals.append(batch_acc)


            print(f"batch_loss :: {batch_loss} // batch_acc :: {batch_acc}")
            # loss_vals, grads = jax.value_and_grad(loss, has_aux=False)(optimizer.target, x_batch, y_true)
            # optimizer = optimizer.apply_gradient(grads)
            
        epoch_loss = jnp.average(jnp.array(loss_vals))
        epoch_acc = jnp.average(jnp.array(acc_vals))

        if epoch % print_interval == 0:
            print("\tat epoch", epoch, "have loss", epoch_loss, "and acc", epoch_acc)

        if epoch_loss < 1e-5:
            break

            print('finally', batch_loss) 

Breakdown of code overall:


*   MDEQ modul
*   List item



In [20]:
train()

Files already downloaded and verified
Files already downloaded and verified




in model (128, 32, 32, 3)
shape_list [(128, 32, 32, 24), (128, 16, 16, 24)]
test


the following:

  import threading
  threading.current_thread().pydev_do_not_trace = True



in model (128, 32, 32, 3)
	at pos 0
x0 type atm <class 'list'>
bsz2e 128
x0 type as vec <class 'jaxlib.xla_extension.DeviceArray'>
		at pos 00
x0 type inside rootfind fwd <class 'jaxlib.xla_extension.DeviceArray'>
		at pos 11
		at pos 22
	at pos 1
zshape and shapelist (128, 30720, 1) [(128, 32, 32, 24), (128, 16, 16, 24)]
	at pos 2
	at pos 3
_fwd_rootfind_bwd seems ok
	at pos 4
solved mdeq model


Exception: ignored