<a href="https://colab.research.google.com/github/AggelosAr/3d_CONV/blob/main/3dCONV.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

from math import log10, sqrt
import numpy as np
import tensorflow as tf
from tensorflow.keras import mixed_precision
import time
import timeit
import os
from matplotlib import pyplot as plt


from google.colab import drive
drive.mount('/content/drive')

# Set GPU/XLA runtime options before anything initialises the GPU, so every
# device TensorFlow creates sees them (the GPU check below already creates
# devices).
os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
os.environ['TF_GPU_THREAD_COUNT'] = '1'
# XLA options are passed through the XLA_FLAGS environment variable; keep any
# flags that were already set.
os.environ['XLA_FLAGS'] = (
    os.environ.get('XLA_FLAGS', '') + ' --xla_gpu_autotune_level=4').strip()

device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

# Global Keras dtype policy: applies to Keras layers (e.g. the Conv3D used for
# benchmarking later), not to raw tf ops.
mixed_precision.set_global_policy('mixed_float16')
# tensor_float_32_execution_enabled() only queries the setting and its result
# was discarded; enable TF32 explicitly (already the default in recent
# TensorFlow releases).
tf.config.experimental.enable_tensor_float_32_execution(True)


In [None]:


def PSNR(original, compressed, max_pixel=255.0):
    """Return the peak signal-to-noise ratio between two signals, in dB.

    Args:
        original: Reference signal (array-like, including eager tensors).
        compressed: Signal to compare; must broadcast against ``original``.
        max_pixel: Peak value of the signal range. Defaults to 255.0 (8-bit
            images); pass the actual data range for other signals.

    Returns:
        ``20 * log10(max_pixel / sqrt(MSE))`` as a float, or 100 when the
        signals are identical (MSE == 0), where PSNR is unbounded.
    """
    # Promote to float64 first: with unsigned integer inputs (e.g. uint8
    # images) the subtraction and squaring would otherwise wrap around.
    diff = (np.asarray(original, dtype=np.float64)
            - np.asarray(compressed, dtype=np.float64))
    mse = np.mean(diff ** 2)
    if mse == 0:
        # No noise at all: PSNR is infinite, report a fixed sentinel.
        return 100
    return 20 * log10(max_pixel / sqrt(mse))


class myConv():
    """3x3x3 convolution computed with precomputed block-transform matrices.

    The input is cut into 2x2 spatial tiles, frames are grouped in
    non-overlapping pairs, and each (2, 2, 2) tile is zero-padded to
    (4, 4, 4). Each flattened tile ``x`` is combined with a kernel ``h`` as
    ``A_f_T @ ((A @ x) * (C_T @ h))``. The per-tile results are then
    reassembled with overlap-add. The notebook checks the result against
    ``tf.nn.conv3d(..., padding="SAME")`` via PSNR.

    Args:
        no_kernels: Number of output feature maps (kernels).
        input_shape: ``(batches, frames, w, h, channels)`` of the inputs this
            instance will be called on; all shape bookkeeping is fixed here.
        a_matrix, a_f_t, c_t: Precomputed transform matrices ``A``, ``A_f_T``
            and ``C_T``. When omitted, the module-level globals of those names
            are used (the original behaviour).
        dtype: dtype for the kernel initializer; defaults to the module-level
            global ``precision``.
    """

    def __init__(self, no_kernels, input_shape, a_matrix=None, a_f_t=None,
                 c_t=None, dtype=None):
        super(myConv, self).__init__()

        # Precomputed transform matrices. Callers may inject them; otherwise
        # fall back to the module-level globals A / A_f_T / C_T, which requires
        # the cell defining them to have run before construction.
        self.A = A if a_matrix is None else a_matrix
        self.A_f_T = A_f_T if a_f_t is None else a_f_t
        self.C_T = C_T if c_t is None else c_t
        if dtype is None:
            dtype = precision

        self.kern_num = no_kernels

        # Padding that grows each (2 frames, 2, 2) tile to (4, 4, 4).
        self.paddings_image = tf.constant([[0, 0], [0, 2], [0, 0], [0, 2], [0, 2]])
        self.bacthes = input_shape[0]  # (sic) attribute name kept for compatibility
        self.frames = input_shape[1]
        self.w = input_shape[2]
        self.h = input_shape[3]
        self.initial_channels = input_shape[4]

        # Non-overlapping frame pairs across all channels, and 2x2 tiles per frame.
        self.groups_of_2 = (input_shape[1] // 2) * self.initial_channels
        self.blocks = (input_shape[2] * input_shape[3]) // 4

        # Kernel in the layout tf.nn.conv3d uses: (3, 3, 3, in_channels, no_kernels).
        default_kernel = tf.keras.initializers.GlorotUniform(seed=0)(
            shape=(3, 3, 3, self.initial_channels, self.kern_num), dtype=dtype)
        self.default_kernel = default_kernel

        self.filters = tf.Variable(default_kernel, trainable=True)

        # Reorder to (no_kernels, in_channels, 3, 3, 3) and flip the three
        # spatial axes (presumably so the transform-domain convolution matches
        # tf.nn.conv3d's cross-correlation). Then zero-pad each spatial axis
        # to length 4 so the kernels line up with the 4x4x4 input tiles.
        # NOTE(review): this is derived once, here, from the initial value of
        # self.filters. __call__ never reads self.filters, so later updates to
        # the variable are not reflected and no gradient reaches it. Confirm
        # whether the layer is meant to be trainable.
        filters = tf.experimental.numpy.moveaxis(self.filters, 4, 0)
        filters = tf.experimental.numpy.moveaxis(filters, 4, 1)
        filters = tf.reverse(filters, axis=[2, 3, 4])
        paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 1], [0, 1]])
        self.useable_filters = tf.pad(filters, paddings, "CONSTANT")

    def getme(self):
        """Return the initial kernel, shape (3, 3, 3, channels, no_kernels)."""
        return self.default_kernel

    def trans_conv(self, args):
        """Combine one flattened tile with one transformed kernel.

        Args:
            args: ``(x, H)``. ``x`` is a flattened tile of shape (64, 1); ``H``
                is the matching slice of ``C_T @ h`` from ``conv_channel``.

        Returns:
            ``A_f_T @ ((A @ x) * H)``.
        """
        x, H = args
        return self.A_f_T @ (self.A @ x * H)

    def img_to_blcks(self, frames):
        """Cut one (channel, frame) slice into non-overlapping 2x2 tiles.

        Called through ``tf.vectorized_map`` from ``__call__``. Returns a
        tensor of shape ``(n_tiles, 2, 2)``. Axes 1 and 2 must be even for
        ``tf.split``. Zero padding of the tiles happens later, in ``__call__``.

        NOTE(review): presumably ``frames`` is ``(batches, w, h, 1)``. The final
        reshape then only fits when ``batches == 1``, matching the single-batch
        limitation noted in ``__call__``.
        """
        a1 = tf.split(frames, frames.shape[2]//2, axis=2)
        a2 = tf.split(a1, frames.shape[1]//2, axis=2)
        a3 = tf.reshape(a2, shape=((tf.shape(frames)[2]//2)*(tf.shape(frames)[1]//2), 2, 2))
        return a3

    def call_batches(self, inputs):
        """Convolve every padded tile of one batch element with every kernel.

        Args:
            inputs: Padded tiles of one batch element, shape
                ``(groups_of_2 * blocks, 4, 4, 4)``.

        Returns:
            The stacked per-kernel results of ``kernels``.
        """
        # Split the padded kernels (no_kernels, channels, 4, 4, 4) per input
        # channel, repeat each once per tile of that channel, and concatenate.
        # This gives one kernel per tile for every output kernel.
        filters1 = tf.split(self.useable_filters, self.useable_filters.shape[1], axis=1)
        filters1 = tf.repeat(filters1, (self.groups_of_2//self.initial_channels)*self.blocks, axis=2)
        filters1 = tf.concat(tf.split(filters1, filters1.shape[0], axis=0), 2)
        filters1 = tf.squeeze(filters1, axis=0)
        # One copy of the tiles per output kernel.
        test_inputs = tf.repeat(tf.expand_dims(inputs, axis=0), self.kern_num, axis=0)
        return self.kernels(test_inputs, filters1)

    def kernels(self, x, all_h):
        """Run ``conv_channel`` for every output kernel in parallel.

        ``x`` and ``all_h`` are mapped together over their leading (kernel) axis.
        """
        return tf.vectorized_map(self.conv_channel, (x, all_h))

    def conv_channel(self, args):
        """Convolve all tiles with the per-tile kernels of one output kernel.

        Args:
            args: ``(x, h)`` with matching leading (tile) axes; each tile and
                kernel holds 64 values (4x4x4).
        """
        x, h = args

        x = tf.reshape(x, shape=(tf.shape(x)[0], 64, 1))
        h = tf.reshape(h, shape=(tf.shape(h)[0], 64, 1))

        # With one channel this is fine. With many channels, this matmul (and
        # the reshapes above) could be hoisted into kernels/call_batches, so the
        # same kernel transform is not recomputed on every call.
        h = self.C_T@h

        return tf.vectorized_map(self.trans_conv, (x, h))

    @tf.function(jit_compile=True)
    def __call__(self, inputs):
        """Convolve ``inputs`` with every kernel.

        Args:
            inputs: Tensor of shape ``(batches, frames, w, h, channels)``
                matching the ``input_shape`` given to the constructor.

        Returns:
            The convolved feature maps, divided by 64. They are intended to
            match ``tf.nn.conv3d(..., padding="SAME")``; TODO(review) confirm
            the exact output shape.
        """
        # Lay all channels out one after another along the frame axis, then
        # move that axis first so tiling can be vectorized over it.
        inputs = tf.concat(tf.split(inputs, inputs.shape[4], axis=4), 1)
        inputs = tf.experimental.numpy.moveaxis(inputs, 0, 1)

        # Cut every frame into 2x2 tiles, group frames in non-overlapping
        # pairs and zero-pad each (2, 2, 2) tile to (4, 4, 4).
        img_blks = tf.vectorized_map(self.img_to_blcks, inputs)
        img_blks = tf.reshape(img_blks, (tf.shape(img_blks)[0]//2, 2, tf.shape(img_blks)[1],
                                         tf.shape(img_blks)[2], tf.shape(img_blks)[3]))
        img_blks = tf.pad(img_blks, tf.constant([[0, 0], [0, 2], [0, 0], [0, 2], [0, 2]]), "CONSTANT")

        padded_blocks = tf.experimental.numpy.moveaxis(img_blks, 1, 2)
        padded_blocks = tf.expand_dims(padded_blocks, axis=0)
        padded_blocks = tf.reshape(padded_blocks, (self.bacthes, self.groups_of_2*self.blocks, 4, 4, 4))
        # TODO: the tiling above currently only works for a single batch.

        co_blocks = tf.vectorized_map(self.call_batches, padded_blocks)

        # Reshape each flattened 64-vector back to 4x4x4.
        co_blocks = tf.reshape(co_blocks, shape=(tf.shape(co_blocks)[0], tf.shape(co_blocks)[1],
                                                 tf.shape(co_blocks)[2], 4, 4, 4))

        # Blocks -> image. Leading axes become (batches, kernels); split the
        # tile axis back into (frame pairs, tiles).
        co_blocks = tf.reshape(co_blocks, (self.bacthes, self.kern_num,
                                           self.groups_of_2, self.blocks, 4, 4, 4))

        # groups of 2 | frames | blocks | ... (h|w vs w|h order does not matter)
        co_blocks = tf.experimental.numpy.moveaxis(co_blocks, 4, 3)

        co_blocks = tf.reshape(co_blocks,
                               (self.bacthes, self.kern_num,
                                tf.shape(co_blocks)[2]*tf.shape(co_blocks)[3],
                                self.blocks, 4, 4))

        co_blocks = tf.reshape(co_blocks,
                               (self.bacthes, self.kern_num,
                                tf.shape(co_blocks)[2],
                                self.w//2, self.h//2, 4, 4))

        co_blocks = tf.transpose(co_blocks, [0, 1, 2, 3, 5, 4, 6])
        co_blocks = tf.reshape(co_blocks, (self.bacthes, self.kern_num,
                                           tf.shape(co_blocks)[2], 1, self.w*2,
                                           self.h*2))

        co_blocks = tf.reshape(co_blocks, (self.bacthes, self.kern_num,
                                           tf.shape(co_blocks)[2], tf.shape(co_blocks)[4],
                                           tf.shape(co_blocks)[5]))

        # Overlap-add the 4-wide tiles with a hop of 2 along both spatial axes.
        # This overlaps the spatial extent of each frame ("squeezing" the
        # frames themselves), not the number of frames.
        oa = tf.reshape(co_blocks,
                        (self.bacthes, self.kern_num,
                         tf.shape(co_blocks)[2], tf.shape(co_blocks)[3], tf.shape(co_blocks)[4]//4, 4))

        oa = tf.signal.overlap_and_add(oa, 2)
        oa = tf.transpose(oa, [0, 1, 2, 4, 3])
        oa = tf.reshape(oa, (self.bacthes, self.kern_num, tf.shape(oa)[2], tf.shape(oa)[3], tf.shape(oa)[4]//4, 4))
        oa = tf.signal.overlap_and_add(oa, 2)
        oa = tf.transpose(oa, [0, 1, 2, 4, 3])

        # The channels were laid out one after another along the frame axis.
        # Split them back apart and sum them to form each feature map.
        maps = tf.math.reduce_sum(tf.split(oa, self.initial_channels, axis=2), axis=0)
        maps = tf.experimental.numpy.moveaxis(maps, 1, 4)

        # X inputs produced 2*X partial maps, which must be summed in order:
        # 1 + 3 | 2 + 4 | 5 + 7 | 6 + 8 ... (starting at index 2).
        # The very first and last elements are skipped. Element 1 is the first
        # output for odd-length input; for even-length input the second-to-last
        # element is appended.
        a_OPT_INDEX = maps[:, 2::2]
        b_OPT_INDEX = maps[:, 3::2]

        a_OPT_INDEX = a_OPT_INDEX[:, 0:-1]
        b_OPT_INDEX = b_OPT_INDEX[:, 0:-1]

        a_OPT_INDEX = tf.split(a_OPT_INDEX, a_OPT_INDEX.shape[1]//2, axis=1)
        b_OPT_INDEX = tf.split(b_OPT_INDEX, b_OPT_INDEX.shape[1]//2, axis=1)

        pre_mid_maps = tf.concat([tf.expand_dims(a_OPT_INDEX, axis=1), tf.expand_dims(b_OPT_INDEX, axis=1)], axis=1)
        pre_mid_maps = tf.math.reduce_sum(pre_mid_maps, axis=3)

        mid_maps = tf.reshape(pre_mid_maps,
                              (tf.shape(pre_mid_maps)[0]*tf.shape(pre_mid_maps)[1],
                               tf.shape(pre_mid_maps)[2], tf.shape(pre_mid_maps)[3], tf.shape(pre_mid_maps)[4],
                               tf.shape(pre_mid_maps)[5]))

        mid_maps = tf.experimental.numpy.moveaxis(mid_maps, 0, 1)
        maps = tf.concat([tf.expand_dims(maps[:, 1], axis=1), mid_maps, tf.expand_dims(maps[:, -2], axis=1)], axis=1)

        # The C matrix carries a factor of 64, so divide the result back down.
        return maps[:, :, 1:-1, 1:-1, :]/64


In [None]:

# Shared dtype for the transform matrices, the test video and myConv's kernel.
precision = tf.float32

# Load the precomputed transform matrices from Google Drive.
curr_path = "/content/drive/MyDrive/fast3/"
A = np.load(os.path.join(curr_path, 'A.npy'))
C = np.load(os.path.join(curr_path, 'C.npy'))

# Flip C vertically and A horizontally, then take the transposes consumed by
# myConv (A_f_T and C_T).
C = np.flipud(C)
A_f = np.fliplr(A)
A_f_T, C_T = A_f.T, C.T

# Convert everything to TensorFlow tensors of the shared dtype.
A, C, A_f_T, C_T = (tf.convert_to_tensor(m, dtype=precision)
                    for m in (A, C, A_f_T, C_T))


In [None]:

# Configuration for the accuracy check against tf.nn.conv3d.
kernels = 1
batches = 1  # only a single batch currently works with myConv
frames = 100  # must be greater than 4
dim1 = dim2 = 128
channels = 1

video = tf.convert_to_tensor(
    np.ones(shape=(batches, frames, dim1, dim2, channels)), dtype=precision)



In [None]:

layerMY = myConv(no_kernels=kernels, input_shape=(batches, frames, dim1, dim2, channels))

# Reference kernel: the same initial weights myConv uses, in tf.nn.conv3d layout.
filters = layerMY.getme()
filters = tf.convert_to_tensor(filters, dtype=precision)
filters = tf.reshape(filters, (3, 3, 3, channels, kernels))

x = layerMY(video)
x1 = tf.nn.conv3d(video, filters, strides=(1, 1, 1, 1, 1), padding="SAME")

value = PSNR(x1, x)
value = np.array(value)
print(f"PSNR value is {value} dB")

# PSNR above uses a fixed 255 peak, while these conv outputs are likely far
# smaller, so it can overstate agreement. Also report a scale-independent error.
x_np = np.asarray(x, dtype=np.float64)
x1_np = np.asarray(x1, dtype=np.float64)
print(f"Max abs error: {np.max(np.abs(x1_np - x_np))} "
      f"(reference peak |x1| = {np.max(np.abs(x1_np))})")

In [None]:
kernels = 1
batches = 1  # only a single batch currently works with myConv
frames = 56  # must be greater than 4
dim1 = 18
dim2 = 18
channels = 1

video = np.ones(shape=(batches, frames, dim1, dim2, channels))
video = tf.convert_to_tensor(video, dtype=precision)

layerMY = myConv(no_kernels=kernels, input_shape=(batches, frames, dim1, dim2, channels))
# Keras `input_shape` is the per-sample shape and must not include the batch
# dimension.
# NOTE(review): with the global 'mixed_float16' policy set at the top of the
# notebook, this layer computes in float16 while myConv uses `precision`, so
# the timing comparison below is not like-for-like. Confirm this is intended.
layerTE = tf.keras.layers.Conv3D(kernels, (3, 3, 3), strides=(1, 1, 1),
                                 padding="same",
                                 input_shape=(frames, dim1, dim2, channels),
                                 use_bias=False)

In [None]:

def my():
  """Run the custom convolution layer once on the benchmark video."""
  x = layerMY(video)
  return x


# Warm-up call so tracing / XLA compilation is excluded from the timing.
my()


def _time_my(number=1000):
  """Call my() `number` times, then wait for the queued device work.

  GPU ops are dispatched asynchronously, so without a final host copy the
  measured time can exclude work that is still running on the device.
  """
  out = None
  for _ in range(number):
    out = my()
  if out is not None:
    out.numpy()  # blocks until the last result is ready


my_time = timeit.timeit(_time_my, number=1)

print(my_time)

In [None]:

def tens():
  """Run the Keras Conv3D reference layer once on the benchmark video."""
  x = layerTE(video)
  return x


# Warm-up call so layer building / tracing is excluded from the timing.
tens()


def _time_tens(number=1000):
  """Call tens() `number` times, then wait for the queued device work.

  GPU ops are dispatched asynchronously, so without a final host copy the
  measured time can exclude work that is still running on the device.
  """
  out = None
  for _ in range(number):
    out = tens()
  if out is not None:
    out.numpy()  # blocks until the last result is ready


tens_time = timeit.timeit(_time_tens, number=1)

print(tens_time)