### Train partial models by partially freezing layers, then combine the models by adding them together.

In [57]:
import tensorflow as tf
import numpy as np
from itertools import product

class PartialFreeze(tf.keras.layers.Layer):
    """Keep only a rectangular sub-region of each sample during training.

    During training the input is multiplied element-wise by a fixed 0/1 mask.
    Entries inside the ``cut`` region (mask == 1) pass through unchanged.
    Entries outside it are zeroed, so no gradient flows back through those
    positions. Outside of training the input is returned unchanged.
    """

    def __init__(self, axis=None, cut=None, **kwargs):
        """
        Configure which region of each sample is kept during training.

        :param axis: int or sequence of ints; the non-batch axes to restrict.
            Axis 0 is the first axis after the batch dimension. Negative
            values count back from the last axis.
        :param cut: ``(start, end)`` for a single axis, or a sequence of such
            pairs, one per entry in ``axis``. Each pair is passed to
            ``slice``, so ``end`` is exclusive and ``None`` is open-ended.
        :param kwargs: standard Keras layer arguments (``name``, ``dtype``, ...).
        :raises ValueError: if ``axis`` or ``cut`` is missing, or if their
            lengths differ.
        """
        super().__init__(**kwargs)
        if axis is None or cut is None:
            raise ValueError("PartialFreeze requires both `axis` and `cut`.")

        # Normalize axis and cut independently, so a single int axis works
        # with either a single (start, end) pair or a one-element sequence of
        # pairs. Lists are accepted because JSON model configs turn tuples
        # into lists.
        if isinstance(axis, (int, np.integer)):
            self.axis = (axis,)
        else:
            self.axis = tuple(axis)

        if isinstance(cut[0], (tuple, list)):
            self.cut = tuple(tuple(bounds) for bounds in cut)
        else:
            self.cut = (tuple(cut),)

        if len(self.axis) != len(self.cut):
            raise ValueError(
                f"Got {len(self.axis)} axes but {len(self.cut)} cuts; "
                "they must match one-to-one."
            )

        # Created in build(), once the input shape is known.
        self.mask = None

    def build(self, input_shape):
        """Create the 0/1 mask for a single sample (batch dimension excluded).

        The mask is 1 inside the ``cut`` range on every restricted axis (and
        across the full extent of every other axis) and 0 elsewhere.

        :raises ValueError: if an entry of ``axis`` is out of range.
        """
        # Drop the batch dimension; the mask broadcasts over the batch in call().
        shape = tuple(input_shape)[1:]
        rank = len(shape)

        # One slice per non-batch axis: the full range by default, the cut
        # range on the restricted axes. A single slice assignment marks the
        # whole (outer-product) region at once.
        index = [slice(None)] * rank
        for ax, bounds in zip(self.axis, self.cut):
            normalized = ax + rank if ax < 0 else ax
            if not 0 <= normalized < rank:
                raise ValueError(
                    f"axis {ax} is out of range for input of rank {rank} "
                    "(batch dimension excluded)."
                )
            index[normalized] = slice(*bounds)

        self.mask = np.zeros(shape)
        self.mask[tuple(index)] = 1.0
        super().build(input_shape)

    def call(self, x, training=None):
        """Zero everything outside the cut region when training; else identity."""
        if not training:
            # Inference (or unspecified phase): use the full, unmasked input.
            return x
        # Multiplying by zero blocks the gradient for the masked positions.
        # Cast so the mask matches the input dtype (e.g. float32 / float16).
        return x * tf.cast(self.mask, x.dtype)

    def get_config(self):
        """Return the base layer config plus ``axis`` and ``cut``."""
        config = super().get_config()
        config.update({"axis": self.axis, "cut": self.cut})
        return config

In [60]:
from tensorflow.keras.layers import *

# Demo: keep rows 0-3 and columns 0-2 of each 8x8 sample (everything else is
# zeroed while training). Applying the layer to a symbolic Input triggers its
# build(), which is what creates `partial.mask`.
partial = PartialFreeze(axis=(0, 1), cut=((0, 4), (0, 3)))
x = partial(Input(shape=(8, 8)))

print(partial.mask, partial.mask.shape)

(8, 8)
(8, 8)
[[1. 1. 1. 0. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]] (8, 8)
