In [None]:
import numpy as np
import time, os, sys, gc
from IPython.display import clear_output

import matplotlib.pyplot as plt
from matplotlib.image import imread
%matplotlib inline
plt.style.use('dark_background')

In [None]:
import tensorflow as tf
# Switch every visible GPU to on-demand memory growth.
# (A fixed per-device memory limit via set_virtual_device_configuration was
# considered as an alternative and is intentionally not enabled.)
gpus = tf.config.experimental.list_physical_devices('GPU')
enable_memory_growth = tf.config.experimental.set_memory_growth
for physical_gpu in gpus:
    enable_memory_growth(physical_gpu, True)


from keras import layers, models, backend
import keras.backend as K

from keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from keras.utils import get_custom_objects

# import tensorflow_addons as tfa
# import tensorflow_probability as tfp

# Run a garbage-collection pass and reset Keras' global state, so that
# re-running this cell does not keep building on objects from earlier runs.
gc.collect()
tf.keras.backend.clear_session()

# from keras import mixed_precision
# mixed_precision.set_global_policy('mixed_float16')

In [None]:
from custom_layers import *
from custom_encodingdecoding import *

In [None]:
# - One Model To Rule Them All, ... And Within Its Latents Bind Them - Recursive Recontextualization and Dynamic Distillation (R2D2)
#
# Project Description:
# A dynamically recursive and scalable multi-modal neural network architecture.
# Dynamically recursive in the sense that the model operates off of a single recursive block,
# wherein a hybrid processing block creates an intermediate view from the data patches, combined with a view from the working context,
# to calculate updates to the working context. Then the block weights are updated, and we use the intermediates and the context to predict
# an update to the data patches, and the process repeats.
# A positional embedding is used to add positional information to the data patches; in addition, a type embedding is used to
# help the model differentiate between types of data patches and to identify target data.

In [None]:
class R2D2(tf.keras.models.Model):
    """Recursive Recontextualization and Dynamic Distillation (R2D2) model.

    Every input item is encoded by the encoder registered for its mode, then a
    single recursive block is applied ``depth`` times. The matrices used inside
    the block are not ordinary trainable kernels: ``get_weights`` generates
    them from a context vector on every step.

    Args:
        data_size: Width of the encoded data vectors; fixes the data-side
            dimensions of the generated weight matrices.
        context_size: Width of the working context vector.
        head_width: Base width of the positional-encoding stack.
            ``context_size // head_width`` gives ``num_heads``, which decides
            how many (and how wide) positional-encoding weights are created.
        weight_gen_size: Number of units in the bottleneck layer that feeds the
            weight generator.
        modes: Sequence of mode identifiers. The position of a mode in this
            sequence is its index in the type embedding, and the identifiers
            are used to look up ``data_encoders`` / ``data_decoders``.
        depth: Number of recursive steps executed by ``call``.
        activation: Callable activation. It is handed to the Dense layers and
            also applied directly to the intermediate views.
        encoder_decoder_params: Keyword arguments forwarded to
            ``get_encoders_and_decoders``.
    """

    def __init__(self, data_size, context_size, head_width, weight_gen_size, modes, depth, activation, encoder_decoder_params):
        super(R2D2, self).__init__()
        self.data_size = data_size
        self.context_size = context_size
        self.head_width = head_width
        self.weight_gen_size = weight_gen_size
        self.modes = modes
        self.depth = depth
        self.act = activation
        self.encoder_decoder_params = encoder_decoder_params

        self.data_encoders, self.data_decoders = get_encoders_and_decoders(self.data_size, self.modes, **self.encoder_decoder_params)

        self.num_heads = self.context_size // self.head_width

        self.att_kernel_size = 7 # should possibly increase depending on internal sizes

        # positional encoding stack: one weight per scale, the width halving
        # (num_heads // scale heads) until a single head is left
        self.wpos = []
        scale = 1
        while self.num_heads // scale >= 2:
            self.wpos.append(self.add_weight(name='widthwise_positional_encoding_h' + str(self.num_heads // scale),
                                             shape=(1, self.head_width * (self.num_heads // scale)),
                                             initializer='random_normal', trainable=True))
            scale *= 2
        self.wpos.append(self.add_weight(name='widthwise_positional_encoding_h1', shape=(1, self.head_width), initializer='random_normal',
                                      trainable=True))

        # locality embeddings
        self.data_locality_dense = layers.Dense(64, activation=self.act, name='data_locality_dense')
        self.data_locality_kernel = layers.Dense(self.att_kernel_size, activation='linear', name='data_locality_kernel', use_bias=False)

        # type embedding
        self.type_embedding = layers.Embedding(len(self.modes), self.context_size, name='type_embedding')

        # context initializer
        self.context_initializer = self.add_weight(name='context_initializer', shape=(1, self.context_size), initializer='random_normal', trainable=True)

        # block weights: name -> (rows, cols) of each generated matrix.
        # Insertion order matters: it is the order in which the flat generator
        # output is split and named in `get_weights`.
        self.weight_shapes = {}

        self.weight_shapes['data_view'] = (self.data_size, self.data_size)
        self.weight_shapes['context_view'] = (self.context_size, self.data_size)

        self.weight_shapes['context_update'] = (self.data_size, self.context_size)
        # self.weight_shapes['context_update_importance'] = (self.data_size, 2)

        self.weight_shapes['context_view_2'] = (self.context_size, self.data_size)
        self.weight_shapes['data_update'] = (self.data_size, self.data_size)
        # self.weight_shapes['data_update_importance'] = (self.data_size, 2)

        # weight generation

        # Derived from the dict above (instead of a hand-maintained copy) so the
        # shapes can never drift from the names they are zipped with later.
        self.list_of_weight_shapes = list(self.weight_shapes.values())
        # Plain Python ints: these values are used as the `units` of a Dense
        # layer and as `tf.split` sizes, neither of which should be a tensor.
        self.weight_lens = [int(rows) * int(cols) for rows, cols in self.list_of_weight_shapes]
        total_weights = sum(self.weight_lens)

        self.weight_gen_comp = layers.Dense(self.weight_gen_size, activation=self.act, name='weight_gen_comp')
        self.weight_gen = layers.Dense(total_weights, activation='linear', name='weight_gen')

        # weight norms
        self.weight_norms = [layers.LayerNormalization(name='weight_norm_' + str(i)) for i in range(len(self.list_of_weight_shapes))]

        self.inverse_weight_proc = layers.Dense(self.context_size, activation=self.act, name='inverse_weight_proc')

    # aggregate the positional encodings
    def get_positional_encoding(self):
        """Combine the positional-encoding stack into a single encoding.

        Starts from ``wpos[0]``. Every further entry is repeated
        ``ceil(context_size / width)`` times, each repeat is scaled by a
        linearly decaying factor (1 for the first repeat down to
        ``1 / num_repeats`` for the last), the repeats are flattened, truncated
        to ``context_size`` and added to the running sum.

        Returns:
            The summed positional encoding tensor.
        """
        positional_encoding = self.wpos[0]
        for i in range(1, len(self.wpos)):
            wpos = self.wpos[i]
            # repeat the positional encoding as many times as needed
            num_repeats = tf.cast(tf.math.ceil(self.context_size / tf.shape(wpos)[1]), tf.float32)
            num_repeats_int = tf.cast(num_repeats, tf.int32)
            wpos = tf.repeat(wpos, num_repeats_int, axis=0)
            # multiply by linear decay
            wpos = wpos * tf.expand_dims(tf.cast(tf.range(num_repeats_int, 0, -1), tf.float32) / num_repeats, axis=-1)
            # unstack the repeats
            wpos = tf.reshape(wpos, (-1,))
            # truncate to the correct length and add to the positional encoding
            positional_encoding = positional_encoding + wpos[:self.context_size]
        return positional_encoding

    def get_locality_embedding(self, patch_dims):
        """Turn patch locality descriptors into embeddings of the positional encoding.

        ``patch_dims`` is mapped through two Dense layers to
        ``att_kernel_size`` values per row; these are reshaped into 1-D
        convolution filters that are slid over the aggregated positional
        encoding.

        Args:
            patch_dims: Locality descriptor produced by a data encoder
                (presumably one row per patch -- TODO confirm against the
                encoders).

        Returns:
            The convolved positional encoding with the batch axis dropped and
            the filter axis first.
        """
        x = self.data_locality_dense(patch_dims)
        k = self.data_locality_kernel(x)
        # reshape k into a kernel: (kernel width, 1 input channel, output channels)
        # NOTE(review): this is a plain reshape, not a transpose. If k is
        # (patches, att_kernel_size), the taps of one filter are taken from
        # several rows of k -- confirm this is intended.
        k = tf.reshape(k, (self.att_kernel_size, 1, -1))
        # apply the kernel to the positional encoding
        pos_enc = tf.expand_dims(self.get_positional_encoding(), axis=-1)
        output = tf.nn.conv1d(pos_enc, k, stride=1, padding='SAME')
        # transpose the output to the correct shape
        output = tf.transpose(output, (0, 2, 1))
        return output[0] # should fix batch size elsewhere

    # weight generation
    def get_weights(self, context):
        """Generate the block's weight matrices from a context tensor.

        The context goes through the bottleneck layer and the generator layer;
        the flat result is split into one chunk per entry of ``weight_shapes``,
        each chunk is reshaped to its matrix shape and layer-normalized.

        Args:
            context: Context tensor fed to the weight generator.

        Returns:
            dict mapping the keys of ``weight_shapes`` to the generated
            matrices.
        """
        # normalize the context (need to handle zero case)
        # context = (context - tf.math.reduce_mean(context, axis=0, keepdims=True)) / tf.math.reduce_std(context, axis=0, keepdims=True)
        intermediate = self.weight_gen_comp(context)
        weights = self.weight_gen(intermediate)

        # split the weights
        weights = tf.split(weights, self.weight_lens, axis=-1)
        weights = [tf.reshape(w, s) for w, s in zip(weights, self.list_of_weight_shapes)]
        weights = [self.weight_norms[i](w) for i, w in enumerate(weights)]

        # return as a dictionary with names
        return dict(zip(self.weight_shapes.keys(), weights))

    def change_depth(self, new_depth):
        """Set the number of recursive steps used by subsequent calls."""
        self.depth = new_depth

    def process_recursive_step(self, X, E, YS, Y, context, training=False, inverse=False):
        """Run one recursive block step.

        Args:
            X: Encoded input data.
            E: Locality (plus type) embedding aligned with ``X``.
            YS: Selection values; entries ``> 0`` mark the positions to update.
            Y: Target values for the selected positions; only read when
                ``training`` is true.
            context: Working context.
            training: When true, also compute the mean squared error between
                the updated selected data and ``Y``.
            inverse: When true, the context is first passed through
                ``inverse_weight_proc`` before weights are generated from it.

        Returns:
            Tuple ``(X, context, loss)``: the data with the selected positions
            replaced by their updates, the context *as it was passed in*, and
            the loss (``None`` unless ``training``).
        """
        # X is the input data (batch size, data size)
        # E is the locality embedding (batch size, context size)
        # Y is the y selection (batch size, 1)
        # context is the context (batch size, context size)

        # generate weights
        if inverse:
            wcontext = self.inverse_weight_proc(context)
        else:
            wcontext = context
        weights = self.get_weights(wcontext)

        # initial data view
        xX = X
        cC = context

        # apply weights; only the first data_size channels of E are added to the data
        xV = tf.matmul(xX + tf.gather(E, tf.range(self.data_size), axis=-1), weights['data_view'])
        cV = tf.matmul(context + E, weights['context_view'])

        V = xV + cV
        V = self.act(V)

        cU = tf.matmul(V, weights['context_update'])
        cU = tf.math.reduce_mean(cU, axis=0, keepdims=True)
        # normalize the context update
        # cU = (cU - tf.math.reduce_mean(cU, axis=-1, keepdims=True)) / (tf.math.reduce_std(cU, axis=-1, keepdims=True) + 1e-3)
        cC = cC + cU
        # normalize the context
        cC = (cC - tf.math.reduce_mean(cC, axis=-1, keepdims=True)) / (tf.math.reduce_std(cC, axis=-1, keepdims=True) + 1e-3)

        # NOTE(review): tf.where returns one index row per selected element;
        # flattening it and gathering on axis 0 is only meaningful when YS is
        # 1-D -- confirm against the shapes built in `call`.
        ydx = tf.reshape(tf.where(YS > 0.0), (-1,))
        cV2 = tf.matmul(cC + (tf.gather(E, ydx, axis=0)), weights['context_view_2'])
        V2 = tf.gather(V, ydx, axis=0) + cV2
        V2 = self.act(V2)

        # apply data update
        dU = tf.matmul(V2, weights['data_update'])
        # normalize the data update
        # dU = (dU - tf.math.reduce_mean(dU, axis=-1, keepdims=True)) / (tf.math.reduce_std(dU, axis=-1, keepdims=True) + 1e-3)
        uX = tf.gather(xX, ydx, axis=0) + dU
        # normalize the data
        uX = (uX - tf.math.reduce_mean(uX, axis=-1, keepdims=True)) / (tf.math.reduce_std(uX, axis=-1, keepdims=True) + 1e-3)

        # scatter the updated data
        X = tf.tensor_scatter_nd_update(X, tf.expand_dims(ydx, axis=-1), uX)

        loss = None
        if training:
            loss = tf.math.reduce_mean(tf.math.square(uX - Y))

        # NOTE(review): the updated context `cC` is used above but the input
        # `context` is what gets returned, so the caller never sees the
        # update -- confirm whether this is intentional.
        return X, context, loss

    def call(self, inputs, training=False):
        """Encode the inputs and run ``depth`` recursive steps.

        Args:
            inputs: Tuple ``(x, y, mode)``.
                ``x`` holds, per batch item, a sequence of z-step data items.
                ``y`` holds, per batch item and z step, a value where ``> 0``
                marks the step as one to predict.
                ``mode`` holds, per batch item and z step, the mode identifier
                used to pick the encoder/decoder.
            training: When true, the decoder reconstruction loss, the forward
                step loss and the inverse step loss are also computed.

        Returns:
            Tuple ``(FX, losses)``. ``FX`` is the data after the last recursive
            step. ``losses`` is ``[gloss, gtloss, giloss, dec_loss]`` when
            ``training`` is true, otherwise a list of four ``None``.
        """
        x, y, mode = inputs
        # x is input data (batch size, z length, *(data size, differs by mode))
        # y represents what z steps are being predicted (batch size, z) where each z scales from 0 to 1, 0 if not predicted, 1 if predicted
        # if training, we use y to mask/corrupt corresponding x data for autorregressive training,
        # otherwise we use it to determine what z steps to predict
        # mode determines how z steps are encoded and decoded. (batch size, z)

        # assert len(x) == len(y) == len(mode)

        # since batches and z steps are generally ragged, we're just going to encode them separately for now
        bs = len(x) # tf.shape(x)[0]

        FX = []
        EE = []
        PS = []
        YS = []
        dec_loss = 0.0
        for b in range(bs):
            bx = x[b]
            by = y[b]
            bmode = mode[b]
            fx = []
            ee = []
            ps = []
            ys = []
            for z in range(len(bx)): # range(tf.shape(bx)[0]):
                # the second encoder input is the (negative) offset of this step from the end of the sequence
                ex, p_s, ped = self.data_encoders[bmode[z]]([bx[z], z - len(bx)]) # tf.shape(bx)[0]])
                # normalize ex
                # ex = (ex - tf.math.reduce_mean(ex, axis=-1, keepdims=True)) / (tf.math.reduce_std(ex, axis=-1, keepdims=True) + 1e-3)
                if training: # decode the data and compare to the original
                    dex = self.data_decoders[bmode[z]]([ex, p_s])
                    dec_loss += tf.math.reduce_mean(tf.math.square(dex - bx[z])) * (1.0 / len(bx))

                # use ped to get the locality embedding
                locality_embedding = self.get_locality_embedding(ped)

                # we also need to copy by[z] by len of ex for selecting updates later
                byz = tf.repeat(tf.expand_dims(by[z], axis=0), tf.shape(ex)[0], axis=0)

                # get type embedding, check bmode[z] against self.modes
                ste = tf.where(tf.equal(bmode[z], self.modes))[0][0]
                ste = self.type_embedding(ste)
                # expand ste to the length of ex
                ste = tf.repeat(tf.expand_dims(ste, axis=0), tf.shape(ex)[0], axis=0)

                fx.append(ex)
                ee.append(locality_embedding + ste)
                ps.append(p_s)
                ys.append(byz)

            FX.append(tf.concat(fx, axis=0))
            EE.append(tf.concat(ee, axis=0))
            PS.append(tf.concat(ps, axis=0))
            YS.append(tf.concat(ys, axis=0))

        FX = tf.stack(FX, axis=0)
        EE = tf.stack(EE, axis=0)
        PS = tf.stack(PS, axis=0)
        YS = tf.stack(YS, axis=0)

        FY = None
        if training:
            # keep the un-masked encodings of the selected positions as targets
            FY = tf.gather(FX, tf.reshape(tf.where(YS > 0.0), (-1,)), axis=0)
        # blank out the selected positions so the model has to predict them
        mYS = tf.expand_dims(YS, axis=-1)
        FX = (FX * (1.0 - mYS)) + (tf.zeros(tf.shape(FX), dtype=tf.float32) * (mYS))

        # FX is the encoded data (batch size * (*variable), data size)
        # LE is the locality embedding (batch size * (*variable), context size)
        # PS is the positional encoding (batch size * zsteps, (*variable))
        # YS is the y selection (batch size * (*variable), 1)

        # initialize the context
        context = self.context_initializer

        gtloss = 0.0
        giloss = 0.0
        for i in range(self.depth):
            # apply the recursive step
            nX, nContext, nloss = self.process_recursive_step(FX, EE, YS, FY, context, training=training)
            if training:
                # inverse step: from the updated data and the negated context,
                # try to recover the selected entries of this step's input
                iFY = tf.gather(FX, tf.reshape(tf.where(YS > 0.0), (-1,)), axis=0)
                iX, iContext, iloss = self.process_recursive_step(nX, EE, YS, iFY, nContext * -1, training=training, inverse=True)

                # both losses are averaged over (depth - 2) steps: the forward
                # loss skips the first two steps, the inverse loss the last two
                if i >= 2: # min recursion for loss
                    gtloss += nloss * (1.0 / (self.depth - 2))
                if i < self.depth - 2:
                    giloss += iloss * (1.0 / (self.depth - 2))
            FX = nX
            context = nContext


        if training:
            gloss = gtloss + giloss + dec_loss
            return FX, [gloss, gtloss, giloss, dec_loss]
        else:
            return FX, [None, None, None, None]

In [None]:
# inds = 2560
# Recursion depth of the test model. The same value is reused below as the
# width of the weight-generator bottleneck (`weight_gen_size`).
dpth = 8

test_model = R2D2(
    # what the model handles
    modes=['image', 'imagenet1k_classification'],
    # internal widths
    data_size=1024,
    context_size=4096,
    # head_width is used: context_size // head_width gives num_heads, which
    # sizes the positional-encoding stack built in R2D2.__init__
    head_width=128,
    # recursion / weight generation
    depth=dpth,
    weight_gen_size=dpth,
    activation=aptx,
    # forwarded as keyword arguments to get_encoders_and_decoders
    encoder_decoder_params=dict(
        patch_size=16,
        conv_mult=1.0,
        vocab_length=1024,
        char_embed_size=16,
    ),
)

In [None]:
# import tensorflow_datasets as tfds
from datasets import load_dataset

In [None]:
# Directory used as the Hugging Face `datasets` cache for the loads below.
# It can be overridden per machine with the R2D2_DATA_CACHE_DIR environment
# variable; the original hard-coded location is kept as the fallback (an empty
# variable also falls back).
data_cache_dir = os.environ.get('R2D2_DATA_CACHE_DIR') or 'S:/Datasets'

In [None]:
# Load the ImageNet-1k training split, caching it under `data_cache_dir`.
imagenet1k = load_dataset('imagenet-1k', cache_dir=data_cache_dir, split='train')

In [None]:
# Number of samples in the training split (shown as the cell output).
len(imagenet1k)

In [None]:
# Build one (image, label) pair from the first training sample for a smoke test:
# pixel values scaled by 1/255, centre-cropped/padded to 256x256, batch axis added.
_first_pixels = tf.keras.utils.img_to_array(imagenet1k[0]['image']) / 255.0
_first_square = tf.image.resize_with_crop_or_pad(_first_pixels, 256, 256)
test_image = tf.expand_dims(_first_square, axis=0)

# One-hot label over the 1000 classes.
test_class = tf.one_hot(imagenet1k[0]['label'], 1000)

In [None]:
# Smoke test: a single forward pass in training mode. The batch holds one item
# with two z steps (the image, then the class); y = [0., 1.] marks only the
# class step as the one to predict.
test_model(([[test_image, test_class]], [[0., 1.]], [['image', 'imagenet1k_classification']]), training=True)

In [None]:
# Print the layer/parameter summary (run after the forward pass above).
test_model.summary()

In [None]:
# Adam with gradient-norm clipping.
# Alternative kept for reference:
# optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9, nesterov=True, name='SGD', clipnorm=0.1)
optimizer = tf.keras.optimizers.Adam(
    learning_rate=0.0003,
    beta_1=0.9,
    beta_2=0.95,
    epsilon=1e-07,
    amsgrad=False,
    clipnorm=0.1,
    name='Adam',
)


# loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
def loss_fn(x):
    """Identity: return the given value unchanged."""
    return x


# Samples per optimisation step; the training loop uses it for the "Seen" counter.
batch_size = 1

In [None]:
# Training loop: one sample per step, image -> class prediction.
epochs = 1
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    for step in range(len(imagenet1k)):
        # Read the dataset row once instead of indexing it separately for the
        # image and for the label.
        sample = imagenet1k[step]

        # Force three channels: ImageNet contains a few grayscale/CMYK files,
        # which would otherwise change the channel count of the input.
        x_batch_train = tf.expand_dims(tf.keras.utils.img_to_array(sample['image'].convert('RGB')) / 255.0, axis=0)

        # Resize so the shorter side is 256 (aspect ratio kept), then
        # centre-crop to 256x256. im_x is the height axis, im_y the width axis.
        im_x = x_batch_train.shape[-3]
        im_y = x_batch_train.shape[-2]
        if im_x > im_y:
            x_batch_train = tf.image.resize(x_batch_train, (int((im_x / im_y) * 256), 256))
        else:
            x_batch_train = tf.image.resize(x_batch_train, (256, int((im_y / im_x) * 256)))
        x_batch_train = tf.image.resize_with_crop_or_pad(x_batch_train, 256, 256)
        y_batch_train = tf.one_hot(tf.convert_to_tensor(sample['label']), 1000)


        with tf.GradientTape() as tape:
            # y = [0.0, 1.0]: only the class step is selected for prediction
            logits, losses = test_model(([[x_batch_train, y_batch_train]], [[0.0, 1.0]], [['image', 'imagenet1k_classification']]), training=True)

            # Decode the last encoded position back into class probabilities.
            cat_preds = test_model.data_decoders['imagenet1k_classification']([logits[:,-1], [1,]])[0]
            cat_loss = tf.keras.losses.categorical_crossentropy(y_batch_train, cat_preds)
            # Total loss: forward (losses[1]) and inverse (losses[2]) step
            # losses at full weight, classification and decoder
            # reconstruction (losses[3]) down-weighted to 0.1.
            loss_value = (cat_loss * 0.1) + losses[1] + losses[2] + (losses[3] * 0.1)

        grads = tape.gradient(loss_value, test_model.trainable_weights)
        optimizer.apply_gradients(zip(grads, test_model.trainable_weights))

        print(
            "Training losses (idv) at step %d:"
            % (step,), [np.round(float(loss_value), 8)] +
            [np.round(float(loss), 8) for loss in losses[1:]] +
            [np.round(float(cat_loss), 8)], "Seen: %s samples" % ((step + 1) * batch_size)
        )

In [None]:
# Load the ImageNet-1k validation split from the same cache directory.
imagenet1k_valid = load_dataset('imagenet-1k', cache_dir=data_cache_dir, split='validation')

In [None]:
# validate the model

# Running sums of the per-sample loss and accuracy, kept as Python floats.
agg_loss = 0.0
agg_acc = 0.0

for step in range(len(imagenet1k_valid)):
    # Read the dataset row once instead of indexing it separately for the
    # image and for the label.
    sample = imagenet1k_valid[step]

    # Force three channels: ImageNet contains a few grayscale/CMYK files,
    # which would otherwise change the channel count of the input.
    x_batch = tf.expand_dims(tf.keras.utils.img_to_array(sample['image'].convert('RGB')) / 255.0, axis=0)

    # Same preprocessing as training: shorter side to 256, then centre-crop.
    im_x = x_batch.shape[-3]
    im_y = x_batch.shape[-2]
    if im_x > im_y:
        x_batch = tf.image.resize(x_batch, (int((im_x / im_y) * 256), 256))
    else:
        x_batch = tf.image.resize(x_batch, (256, int((im_y / im_x) * 256)))
    x_batch = tf.image.resize_with_crop_or_pad(x_batch, 256, 256)
    y_batch = tf.one_hot(tf.convert_to_tensor(sample['label']), 1000)

    # The true label is never shown to the model: a uniform placeholder fills
    # the class slot, which is the step selected for prediction.
    y_zero = tf.zeros_like(y_batch, dtype=tf.float32) + (1.0 / 1000.0)
    # Evaluation runs with training=False so sub-layers are in inference mode
    # and the training-only losses (decoder, inverse step) are not computed.
    # The returned data tensor does not depend on this flag inside R2D2.call.
    encoded, _ = test_model(([[x_batch, y_zero]], [[0.0, 1.0]], [['image', 'imagenet1k_classification']]), training=False)

    # use the decoder in the model to decode for classification
    logits = test_model.data_decoders['imagenet1k_classification']([encoded[:,-1], [1,]])

    y_true = tf.expand_dims(y_batch, axis=0)
    loss_value = tf.keras.losses.categorical_crossentropy(y_true, logits)
    accuracy = tf.keras.metrics.categorical_accuracy(y_true, logits)

    # Reduce the one-element results to Python floats before accumulating, so
    # the "%.4f" formatting below never has to convert a non-scalar tensor.
    agg_loss += float(tf.reduce_mean(loss_value))
    agg_acc += float(tf.reduce_mean(accuracy))

    if step % 100 == 0:
        print("step: %d, loss: %.4f, accuracy: %.4f" % (step, agg_loss / (step + 1), agg_acc / (step + 1)))