In [60]:
import tensorflow as tf
import numpy as np
from keras.layers import Dense

b, w, h, ch, dim, l = 1, 2, 2, 2, 2, 4
heads = 2

img = tf.range(0, b * w * h * ch, dtype=float)
img = tf.reshape(img, (b, w * h, ch))
# so now the image is flattened to a sequence of tokens and every value of the array is unique
k = tf.ones((dim, l))

proj = tf.ones((heads * dim, dim))

qkv = tf.ones((ch, dim * 2 * heads))
# here we do the first einsum which should convert every token to a new dimension.  This should be symmetric in the heads dimension.  If we slice this array along the channel dimension it should be symmetric
inter = tf.einsum('bnc,cd->bnd', img, qkv)
# this is where q and v should come from
print(inter[:, :, :dim*heads].shape, inter[:, :, dim*heads:].shape)
print(np.array_equal(inter[:, :, :dim*heads], inter[:, :, dim*heads:]))
print(inter.shape)
qv = tf.reshape(inter, (b, w * h, heads, 2, dim))
#TODO: verify this reshape is working as expected
print(qv.shape)
qv = tf.transpose(qv, perm=[3, 0, 1, 2, 4])
print(qv.shape)

q, v = qv[0], qv[1]
print(np.array_equal(q, v))
print(q.shape, v.shape)

attn = tf.einsum('bnik,kr->bnir', q, k)
print(attn.shape)

res = tf.einsum('bnir,bnik->brik', attn, v)
print(res.shape)

#TODO: verify this reshape is working as expected
out = tf.reshape(res, (b, l, heads * dim))
print(out[:, :, :dim*heads].shape, out[:, :, dim*heads:].shape)
print(np.array_equal(out[:, :, :dim], out[:, :, dim:]))
print(out.shape)
out = tf.einsum('bnc,cd->bnd', out, proj)
print(out.shape)

(1, 4, 4) (1, 4, 4)
True
(1, 4, 8)
(1, 4, 2, 2, 2)
(2, 1, 4, 2, 2)
True
(1, 4, 2, 2) (1, 4, 2, 2)
(1, 4, 2, 4)
(1, 4, 2, 2)
(1, 4, 4) (1, 4, 0)
True
(1, 4, 4)
(1, 4, 2)


In [9]:
r = tf.range(0, 8)
r = tf.reshape(r, (2, 2, 2))
# so, the last axes are filled like digits from the back
# when the last axis reaches max index the next axis is iterated

In [10]:
r[0]

<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[0, 1],
       [2, 3]], dtype=int32)>

In [11]:
print(r[0,0])

tf.Tensor([0 1], shape=(2,), dtype=int32)


In [12]:
from tensorflow.keras.layers import Dropout

class LunchboxMHSA(tf.keras.layers.Layer):
    def __init__(self,
                 dim,
                 num_heads,
                 lunchbox_dim,
                 packed=False,
                 qkv_bias=True,
                 qk_scale=None,
                 proj_drop=0.,
                 prefix=''):
        """
        Dot product self-attention where the K table is a matrix of learnable
        parameters.  We use the 'Lunchbox' metaphor within this layer where the 
        lunchbox is K, and the savory lunchtime treats are the columns of K.
        Packing is intended to refer to a stage of transfer learning wherein the
        weights that form K are learned from an external task.  The lunchbox
        is considered 'packed' after learning on the external task, and 'unpacked'
        during training on the external task.  Call .pack() to freeze K, call
        .unpack() to unfreeze them.

        :param dim: embedding dimension for k, q, v
        :param num_heads: number of attention heads
        :param packed: whether or not the lunchbox should be considered packed (True -> k is not trainable)
        :param qkv_bias: whether or not to use bias in the initial projection to dim
        :param qk scale: scale for rescaling as show in https://arxiv.org/abs/1706.03762, defaults to dim ** -0.5
        :param proj_drop: dropout rate for output
        :param prefix: name of this layer
        """
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.lunchbox_dim = lunchbox_dim
        self.qkv_bias = qkv_bias

        self.scale = qk_scale or max(dim, lunchbox_dim) ** -0.5
        self.prefix = prefix

        self.qkv = None

        self.proj = None
        self.proj_drop = Dropout(proj_drop)
        self.k = None
        self.built = False
        self.packed = packed

    def build(self, input_shape):
        self.k = self.add_weight(f'{self.prefix}/attn/lunchbox',
                                 shape=(self.dim, self.lunchbox_dim),
                                 initializer=tf.initializers.GlorotUniform(), 
                                 trainable=self.packed)
        
        self.proj = self.add_weight(f'{self.prefix}/attn/proj',
                                 shape=(self.num_heads * self.dim, self.dim),
                                 initializer=tf.initializers.GlorotUniform(), 
                                 trainable=True)

        self.qkv = self.add_weight(f'{self.prefix}/attn/qkv',
                          shape=(input_shape[-1], self.dim * 2 * self.num_heads),
                          initializer=tf.initializers.GlorotUniform(), 
                          trainable=True)

        self.built = True


    def pack(self):
      # make the table untrainable
      self.k = tf.Variable(self.k, trainable=False)

    def unpack(self):
      # make the table trainable
      self.k = tf.Variable(self.k, trainable=True)

    def call(self, x):

        B_, N, C = x.get_shape().as_list()

        # x = tf.reshape(x, (B_, N, C))
        # (b, n, ch), (ch, dim) -> (b, n, dim * 2 * h)
        x = tf.einsum('bnc,cd->bnd', x, self.qkv)
        x = tf.transpose(x, perm=[0, 2, 1])

        x = tf.reshape(x, (-1, 2, self.num_heads, self.dim, N))
        qv = tf.transpose(x, perm=[1, 0, 4, 2, 3])

        q, v = qv[0], qv[1]

        attn = tf.einsum('bnik,kr->bnir', q, self.k)
        attn = attn * self.scale
        tf.nn.softmax(attn, -1)

        res = tf.einsum('bnir,bnik->brik', attn, v)

        x = tf.reshape(res, (-1, self.lunchbox_dim, self.num_heads * self.dim))

        x = self.proj(x)
        # (b, n, dim * n_h), (ch, dim) -> (b, n, dim)
        x = tf.einsum('bnc,cd->bnd', x, self.proj)

        x = self.proj_drop(x)

        return x

    def get_config(self):
        return {"k": self.k.numpy()}

In [13]:
lb = LunchboxMHSA(2,2,2)
lb.build((1, ))

In [14]:
lb.trainable_weights

[<tf.Variable '/attn/proj:0' shape=(4, 2) dtype=float32, numpy=
 array([[-0.5814152 , -0.8283603 ],
        [ 0.12946153,  0.5329387 ],
        [-0.13148117, -0.22390294],
        [-0.5452256 ,  0.9112253 ]], dtype=float32)>,
 <tf.Variable '/attn/qkv:0' shape=(1, 8) dtype=float32, numpy=
 array([[-0.58784664,  0.3195368 , -0.45753336, -0.15700167, -0.6013385 ,
          0.49908793, -0.26111907,  0.17951947]], dtype=float32)>]

In [15]:
weights = lb.get_weights()
lb.get_weights()

[array([[-0.5814152 , -0.8283603 ],
        [ 0.12946153,  0.5329387 ],
        [-0.13148117, -0.22390294],
        [-0.5452256 ,  0.9112253 ]], dtype=float32),
 array([[-0.58784664,  0.3195368 , -0.45753336, -0.15700167, -0.6013385 ,
          0.49908793, -0.26111907,  0.17951947]], dtype=float32),
 array([[ 0.757552  , -1.1105257 ],
        [ 0.48122764,  0.08933914]], dtype=float32)]

In [82]:
lb.set_weights(weights)

In [83]:
lb.trainable_weights

[<tf.Variable '/attn/proj:0' shape=(4, 2) dtype=float32, numpy=
 array([[-0.6378875 , -0.45413303],
        [ 0.5230794 ,  0.88192725],
        [-0.21150255,  0.99610925],
        [ 0.17696261,  0.57753515]], dtype=float32)>,
 <tf.Variable '/attn/qkv:0' shape=(1, 8) dtype=float32, numpy=
 array([[ 0.2691133 ,  0.68519545,  0.43147898,  0.2203914 , -0.08316743,
          0.05300945, -0.5406209 , -0.5833623 ]], dtype=float32)>]

In [84]:
lb.unpack()

In [85]:
lb.trainable_weights

[<tf.Variable '/attn/proj:0' shape=(4, 2) dtype=float32, numpy=
 array([[-0.6378875 , -0.45413303],
        [ 0.5230794 ,  0.88192725],
        [-0.21150255,  0.99610925],
        [ 0.17696261,  0.57753515]], dtype=float32)>,
 <tf.Variable '/attn/qkv:0' shape=(1, 8) dtype=float32, numpy=
 array([[ 0.2691133 ,  0.68519545,  0.43147898,  0.2203914 , -0.08316743,
          0.05300945, -0.5406209 , -0.5833623 ]], dtype=float32)>,
 <tf.Variable 'Variable:0' shape=(2, 2) dtype=float32, numpy=
 array([[-0.9065486 , -1.1240162 ],
        [ 1.2227877 ,  0.37457538]], dtype=float32)>]

In [86]:
lb.pack()

In [87]:
lb.trainable_weights

[<tf.Variable '/attn/proj:0' shape=(4, 2) dtype=float32, numpy=
 array([[-0.6378875 , -0.45413303],
        [ 0.5230794 ,  0.88192725],
        [-0.21150255,  0.99610925],
        [ 0.17696261,  0.57753515]], dtype=float32)>,
 <tf.Variable '/attn/qkv:0' shape=(1, 8) dtype=float32, numpy=
 array([[ 0.2691133 ,  0.68519545,  0.43147898,  0.2203914 , -0.08316743,
          0.05300945, -0.5406209 , -0.5833623 ]], dtype=float32)>]

In [61]:
from tensorflow.keras.layers import Dropout, Dense

class PCALunchboxMHSA(tf.keras.layers.Layer):
    def __init__(self,
                 dim,
                 num_heads,
                 lunchbox_dim,
                 packed=False,
                 qkv_bias=True,
                 qk_scale=None,
                 proj_drop=0.,
                 prefix=''):
        """
        Dot product self-attention where the K table is a matrix of learnable
        parameters.  We use the 'Lunchbox' metaphor within this layer where the 
        lunchbox is K, and the savory lunchtime treats are the columns of K.
        Packing is intended to refer to a stage of transfer learning wherein the
        weights that form K are learned from an external task.  The lunchbox
        is considered 'packed' after learning on the external task, and 'unpacked'
        during training on the external task.  Call .pack() to freeze K, call
        .unpack() to unfreeze them.

        :param dim: embedding dimension for k, q, v
        :param num_heads: number of attention heads
        :param packed: whether or not the lunchbox should be considered packed (True -> k is not trainable)
        :param qkv_bias: whether or not to use bias in the initial projection to dim
        :param qk scale: scale for rescaling as show in https://arxiv.org/abs/1706.03762, defaults to dim ** -0.5
        :param proj_drop: dropout rate for output
        :param prefix: name of this layer
        """
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.lunchbox_dim = lunchbox_dim
        self.qkv_bias = qkv_bias

        self.scale = qk_scale or max(dim, lunchbox_dim) ** -0.5
        self.prefix = prefix

        self.qkv = None

        self.proj = None
        self.proj_drop = Dropout(proj_drop)
        self.k = None
        self.built = False
        self.packed = packed

    def build(self, input_shape):
        self.k = self.add_weight(f'{self.prefix}/attn/lunchbox',
                                 shape=(self.dim, self.lunchbox_dim),
                                 initializer=tf.initializers.GlorotUniform(), 
                                 trainable=self.packed)
        
        self.proj = self.add_weight(f'{self.prefix}/attn/proj',
                                 shape=(self.num_heads * self.dim, self.dim),
                                 initializer=tf.initializers.GlorotUniform(), 
                                 trainable=True)

        self.qkv = self.add_weight(f'{self.prefix}/attn/qkv',
                          shape=(input_shape[-1], self.dim * 2 * self.num_heads),
                          initializer=tf.initializers.GlorotUniform(), 
                          trainable=True)

        self.built = True


    def pack(self):
      # make the table untrainable
      self.k = tf.Variable(self.k, trainable=False)

    def unpack(self):
      # make the table trainable
      self.k = tf.Variable(self.k, trainable=True)

    def call(self, x):

        B_, N, C = x.get_shape().as_list()

        # x = tf.reshape(x, (B_, N, C))
        # (b, n, ch), (ch, dim) -> (b, n, dim * 2 * h)
        x = tf.einsum('bnc,cd->bnd', x, self.qkv)
        x = tf.transpose(x, perm=[0, 2, 1])

        x = tf.reshape(x, (-1, 2, self.num_heads, self.dim, N))
        qv = tf.transpose(x, perm=[1, 0, 2, 3, 4])

        q, v = qv[0], qv[1]

        attn = tf.einsum('bikj,kr->birk', q, self.k)
        attn = attn * self.scale
        tf.nn.softmax(attn, -1)

        res = tf.einsum('birk,bikn->bink', attn, v)

        x = tf.reshape(res, (-1, N, self.num_heads * self.dim))

        x = self.proj(x)
        # (b, n, dim * n_h), (ch, dim) -> (b, n, dim)
        x = tf.einsum('bnc,cd->bnd', x, self.proj)

        x = self.proj_drop(x)

        return x

    def get_config(self):
        return {"k": self.k.numpy()}

In [1]:
from support.util import Config, Experiment

from trainable.models.vit import build_focal_LAXNet, build_basic_lunchbox
from trainable.models.cnn import build_basic_convnextv2, build_basic_cnn


from data.datasets.image_classification import deep_weeds, cats_dogs, dot_dataset, citrus_leaves
from optimization.data_augmentation.msda import mixup_dset, blended_dset
from optimization.data_augmentation.ssda import add_gaussian_noise_dset, custom_rand_augment_dset, foff_dset

from tensorflow.keras.callbacks import LearningRateScheduler
from optimization.callbacks import EarlyStoppingDifference

from optimization.training_loops.supervised import keras_supervised
from optimization.schedules import bleed_out
"""
hardware_params must include:

    'n_gpu': uint
    'n_cpu': uint
    'node': str
    'partition': str
    'time': str (we will just write this to the file)
    'memory': uint
    'distributed': bool
"""
hardware_params = {
    'name': 'hparam',
    'n_gpu': 4,
    'n_cpu': 16,
    'partition': 'ai2es',
    'nodelist': ['c733'],
    'time': '96:00:00',
    'memory': 16384,
    # The %04a is translated into a 4-digit number that encodes the SLURM_ARRAY_TASK_ID
    'stdout_path': '/scratch/jroth/supercomputer/text_outputs/exp%01a_stdout_%A.txt',
    'stderr_path': '/scratch/jroth/supercomputer/text_outputs/exp%01a_stderr_%A.txt',
    'email': 'jay.c.rothenberger@ou.edu',
    'dir': '/scratch/jroth/AI2ES-DL/',
    'array': '[1]',
    'results_dir': 'results'
}
"""
network_params must include:
    
    'network_fn': network building function
    'network_args': arguments to pass to network building function
        network_args must include:
            'lrate': float
    'hyperband': bool
"""
image_size = (128, 128, 3)

network_params = {
    'network_fn': build_basic_lunchbox,
    'network_args': {
        'lrate': 5e-4,
        'n_classes': 2,
        'iterations': 6,
        'conv_filters': '[32, 48, 64, 96]',
        'conv_size': '[3]',
        'dense_layers': '[16]',
        'learning_rate': [5e-4],
        'image_size': image_size,
        'l1': None,
        'l2': None,
        'alpha': [1, 2**(-10)],
        'beta': [2**(-7)],
        'noise_level': 0.005,
        'depth': 4,
    },
    'hyperband': False
}

"""
experiment_params must include:
    
    'seed': random seed for computation
    'steps_per_epoch': uint
    'validation_steps': uint
    'patience': uint
    'min_delta': float
    'epochs': uint
    'nogo': bool
"""


experiment_params = {
    'seed': 42,
    'steps_per_epoch': 512,
    'validation_steps': 256,
    'patience': 3,
    'min_delta': 0.0,
    'epochs': 64,
    'nogo': False,
}
"""
dataset_params must include:
    'dset_fn': dataset loading function
    'dset_args': arguments for dataset loading function
    'cache': str or bool
    'batch': uint
    'prefetch': uint
    'shuffle': bool
    'augs': iterable of data augmentation functions
"""
dataset_params = {
    'dset_fn': cats_dogs,
    'dset_args': {
        'image_size': image_size[:-1],
        'path': '../data/'
    },
    'cache': False,
    'cache_to_lscratch': False,
    'batch': 64,
    'prefetch': 4,
    'shuffle': True,
    'augs': []
}

optimization_params = {
    'callbacks': [
        # EarlyStoppingDifference(patience=experiment_params['patience'],
        #                        restore_best_weights=True,
        #                        min_delta=experiment_params['min_delta'],
        #                        metric_0='val_clam_categorical_accuracy',
        #                        metric_1='val_clam_1_categorical_accuracy',
        #                        n_classes=2),

        LearningRateScheduler(bleed_out(network_params['network_args']['learning_rate'])),
        # LossWeightScheduler(loss_weight_schedule)
    ],
    'training_loop': keras_supervised
}

config = Config(hardware_params, network_params, dataset_params, experiment_params, optimization_params)

exp = Experiment(config)
exp.run_array(0)


ModuleNotFoundError: No module named 'support'