# load and pack using the schema

In [4]:
import numpy as np
from tensorflow.lite.tools import flatbuffer_utils
from tensorflow.lite.python import schema_py_generated as schema_fb

def load(path):
    """Read the TFLite model stored at ``path``.

    Delegates to ``flatbuffer_utils.read_model_with_mutable_tensors`` and
    returns whatever model object that call produces.
    """
    model = flatbuffer_utils.read_model_with_mutable_tensors(path)
    return model

def save(model, path):
    """Write ``model`` to ``path`` through ``flatbuffer_utils.write_model``.

    Returns ``None``; the only effect is the file written by the helper.
    """
    flatbuffer_utils.write_model(model, path)


def modify_conv2d_to_custom(model):
    """Re-point late CONV_2D operators at a new "CONV_XOR" custom op code.

    A new ``OperatorCodeT`` (builtin code ``CUSTOM``, custom code
    ``"CONV_XOR"``) is always appended to ``model.operatorCodes``. Then, in
    every subgraph, each operator whose op code is ``CONV_2D`` and whose
    position inside that subgraph is greater than 15 has its ``opcodeIndex``
    replaced by the index of the new entry.

    Args:
        model: Mutable model object exposing ``operatorCodes`` and
            ``subgraphs`` (each subgraph exposing ``operators``).

    Returns:
        The same ``model`` object, modified in place.
    """
    # Register the custom op code; its index is the list length before append.
    xor_code = schema_fb.OperatorCodeT()
    xor_code.customCode = "CONV_XOR"
    xor_code.builtinCode = schema_fb.BuiltinOperator.CUSTOM
    custom_op_index = len(model.operatorCodes)
    model.operatorCodes.append(xor_code)

    for graph in model.subgraphs:
        for position, op in enumerate(graph.operators):
            current_code = model.operatorCodes[op.opcodeIndex]
            if current_code.builtinCode != schema_fb.BuiltinOperator.CONV_2D:
                continue
            if position <= 15:
                # Only operators past position 15 of the subgraph are rewritten.
                continue
            op.opcodeIndex = custom_op_index
            # TODO: convert the operator's builtin options to custom options.

    return model



def pack_conv_weights(weights):
    """Bit-pack a 4D array of -1/+1 values along its last (input-channel) axis.

    Each run of 32 consecutive input channels becomes one 32-bit word: bit
    ``i`` of word ``p`` is 1 when ``weights[o, h, w, p * 32 + i] > 0`` and 0
    otherwise. When ``in_channels`` is not a multiple of 32, the unused high
    bits of the last word are left at 0.

    Args:
        weights (numpy.ndarray): 4D array of shape
            (out_channels, height, width, in_channels) containing only
            -1 and 1 (any numeric dtype).

    Returns:
        numpy.ndarray: int32 array of shape
        (out_channels, height, width, ceil(in_channels / 32)). A word whose
        bit 31 is set appears as a negative number.

    Raises:
        ValueError: If ``weights`` is not 4-dimensional or holds a value
            other than -1 and 1.
    """
    weights = np.asarray(weights)
    if weights.ndim != 4:
        raise ValueError(
            "Input array should be 4D (out_channels, height, width, in_channels)"
        )
    if not np.isin(weights, (-1, 1)).all():
        raise ValueError("Input array should only contain -1 and 1")

    out_channels, height, width, in_channels = weights.shape
    # Number of 32-bit words needed to hold every input channel.
    packed_channels = (in_channels + 31) // 32

    # 0/1 bit per channel, zero-padded up to a whole number of words.
    bits = np.zeros(
        (out_channels, height, width, packed_channels * 32), dtype=np.uint32
    )
    bits[..., :in_channels] = weights > 0
    bits = bits.reshape(out_channels, height, width, packed_channels, 32)

    # Channel i inside a word contributes 1 << i. The bits are disjoint, so
    # an unsigned sum equals the bitwise OR and cannot overflow 32 bits.
    place_values = np.left_shift(np.uint32(1), np.arange(32, dtype=np.uint32))
    packed = (bits * place_values).sum(axis=-1, dtype=np.uint32)

    # Reinterpret the same 32 bits as signed to keep the int32 output dtype.
    return packed.view(np.int32)

def modify_conv_weights(model):
    """Rename and bit-pack the weight buffers of selected CONV_2D operators.

    Operators are counted with one running counter across all subgraphs. The
    counter starts at 1 and is incremented before each operator is examined,
    so the first operator seen carries the value 2. An operator is processed
    when its op code is ``CONV_2D`` and the counter is greater than 15.

    For each processed operator the tensor at ``operator.inputs[1]`` is
    renamed ``conv_xor_w_<k>`` (k = 1, 2, ...), its buffer is read as float32,
    reshaped to the tensor's shape, packed with ``pack_conv_weights`` and
    written back as raw bytes. The tensor's recorded shape and type are left
    unchanged.

    Args:
        model: Mutable model object exposing ``subgraphs``, ``operatorCodes``
            and ``buffers``.

    Returns:
        The same ``model`` object, modified in place.
    """
    packed_count = 0
    # NOTE(review): this counter is not reset per subgraph and is offset from
    # the enumerate() index used in modify_conv2d_to_custom - confirm intent.
    op_counter = 1
    for graph in model.subgraphs:
        for op in graph.operators:
            op_counter += 1
            code = model.operatorCodes[op.opcodeIndex]
            is_conv = code.builtinCode == schema_fb.BuiltinOperator.CONV_2D
            if not (is_conv and op_counter > 15):
                continue

            # The weights are taken to be the operator's second input.
            tensor = graph.tensors[op.inputs[1]]
            packed_count += 1
            tensor.name = f'conv_xor_w_{packed_count}'

            storage = model.buffers[tensor.buffer]
            float_weights = np.frombuffer(storage.data, dtype=np.float32)
            float_weights = float_weights.reshape(tensor.shape)
            storage.data = pack_conv_weights(float_weights).tobytes()

    return model

    
def process_model(path, output_path="dfsmn_xor_small.tflite"):
    """Load a model, bit-pack its selected conv weights and save the result.

    Args:
        path: Location of the input model file, passed to ``load``.
        output_path: Destination passed to ``save``. Defaults to
            ``"dfsmn_xor_small.tflite"``, the file name this function has
            always written to.

    Returns:
        None. The converted model is written to ``output_path``.
    """
    model = load(path)
    # The op-code rewrite (modify_conv2d_to_custom) is not applied here; only
    # the weight buffers are rewritten.
    model = modify_conv_weights(model)
    save(model, output_path)
  
        



In [5]:
# Convert 'dfsmn_model.tflite'; process_model saves the converted model itself.
process_model('dfsmn_model.tflite')

In [10]:
# Show entries 32..63 of the last axis at index [0, 0, 0] for the packed
# array and for the source weights, for a side-by-side manual check.
print(packed_weights[0, 0, 0, 32:64])
print(weights[0 ,0, 0, 32:64])

[-1307692680           0           0           0           0           0
           0           0           0           0           0           0
           0           0           0           0           0           0
           0           0           0           0           0           0
           0           0           0           0           0           0
           0           0]
[-1 -1 -1  1  1  1  1 -1  1 -1 -1 -1  1  1 -1 -1 -1  1  1  1 -1 -1 -1 -1
 -1  1 -1 -1  1  1 -1  1]


In [9]:
import numpy as np


def unpack_conv_weights(packed_weights, original_shape):
    """Expand bit-packed conv weights back into an array of -1/+1 values.

    This is the inverse of ``pack_conv_weights`` and uses the same layout:
    the last axis holds the input channels, and bit ``i`` of word ``p``
    encodes channel ``p * 32 + i`` (1 -> +1, 0 -> -1).

    Args:
        packed_weights (numpy.ndarray): Integer array of shape
            (out_channels, height, width, ceil(in_channels / 32)).
        original_shape (tuple): Shape of the unpacked weights,
            (out_channels, height, width, in_channels).

    Returns:
        numpy.ndarray: Integer array of shape ``original_shape`` containing
        only -1 and 1.

    Raises:
        ValueError: If ``packed_weights`` does not have the shape implied by
            ``original_shape``.
    """
    out_channels, height, width, in_channels = tuple(original_shape)
    packed_channels = (in_channels + 31) // 32

    packed = np.asarray(packed_weights)
    expected_shape = (out_channels, height, width, packed_channels)
    if packed.shape != expected_shape:
        raise ValueError(
            f"packed_weights has shape {packed.shape}, expected {expected_shape}"
        )

    # Widen to int64 so the arithmetic right shift of a negative int32 word
    # still exposes its low 32 bits correctly.
    words = packed.astype(np.int64)
    bits = (words[..., np.newaxis] >> np.arange(32)) & 1

    # Flatten (word, bit) back into one channel axis and drop the padding
    # bits that lie beyond in_channels.
    bits = bits.reshape(out_channels, height, width, packed_channels * 32)
    bits = bits[..., :in_channels]

    # Map bit 1 -> +1 and bit 0 -> -1.
    return (bits * 2 - 1).astype(int)

# Round-trip check: random -1/+1 weights are packed, unpacked and compared.
# NOTE: pack_conv_weights reads the shape of its input as
# (out_channels, height, width, in_channels), so the array built below is
# packed along its last axis (the one of length 64).
in_channels, height, width, out_channels = 1, 1, 1, 64

# Random weights drawn from {-1, 1}.
weights = np.random.choice(
    [-1, 1],
    size=(in_channels, height, width, out_channels),
)
print("Original Weights:")
print(weights)

# Pack, then show the packed words.
packed_weights = pack_conv_weights(weights)
print("\nPacked Weights:")
print(packed_weights)

# Unpack using the shape of the original array, then show the result.
unpacked_weights = unpack_conv_weights(packed_weights, weights.shape)
print("\nUnpacked Weights:")
print(unpacked_weights)

# Report whether the round trip reproduced the original values.
print(
    "\nPacking and unpacking are correct!"
    if np.array_equal(unpacked_weights, weights)
    else "\nThere is an issue with packing or unpacking."
)


Original Weights:
[[[[-1  1  1 -1  1  1 -1 -1  1  1  1 -1 -1 -1  1  1  1 -1  1  1  1 -1
    -1  1 -1 -1 -1 -1  1 -1  1 -1 -1 -1 -1  1  1  1  1 -1  1 -1 -1 -1
     1  1 -1 -1 -1  1  1  1 -1 -1 -1 -1 -1  1 -1 -1  1  1 -1  1]]]]

Packed Weights:
[[[[ 1352517430           0           0           0           0
              0           0           0           0           0
              0           0           0           0           0
              0           0           0           0           0
              0           0           0           0           0
              0           0           0           0           0
              0           0 -1307692680           0           0
              0           0           0           0           0
              0           0           0           0           0
              0           0           0           0           0
              0           0           0           0           0
              0           0           0           0  