In [106]:
import numpy as np
import torch
import tfrecord
import cv2
from time import process_time
import tensorflow as tf
from tfrecord.torch.dataset import TFRecordDataset
import os
from feature_CNN import FeatureNet_v1

In [107]:
def _spans_overlap(lo1, hi1, lo2, hi2):
    '''True when either end of one span lies inside the other span (bounds inclusive).'''
    return lo1 <= lo2 <= hi1 or lo1 <= hi2 <= hi1 or lo2 <= lo1 <= hi2 or lo2 <= hi1 <= hi2


def visibility_matrix(torch_df,num_words):
    '''Identify each word's closest neighbour to the right and below, and return the
    resulting neighbourhood graph as an edge list.

    A box is a candidate "right" neighbour when its vertical extent overlaps the
    current box, and a candidate "down" neighbour when its horizontal extent
    overlaps. Among the candidates, the one with the smallest non-negative gap
    (capped at 10**6) is linked; on equal gaps the candidate with the highest row
    index wins. Every link is stored in both directions.

    input:
        torch_df: tensor of shape (rows, >=4) whose first four columns are
            [xmin, ymin, xmax, ymax]; further columns are ignored.
        num_words: number of leading rows that hold real words (int or
            single-element integer tensor). Rows after that are ignored.
    output: torch.long tensor of shape (2, num_edges) with [source, target] node
        indices, ordered by source and then by target.'''

    count = int(num_words)

    # Only the first `count` rows are real words. Move the data to host memory
    # (and out of autograd) so NumPy can read it regardless of where it lives.
    boxes = torch_df.detach().cpu().numpy()[:count, :4]

    #Only create matrix of size matching number of words
    matrix = np.zeros((count, count), dtype=bool)

    for i, (xmin1, ymin1, xmax1, ymax1) in enumerate(boxes):
        min_down = 10**6
        min_right = 10**6
        min_down_idx = None
        min_right_idx = None

        for j, (xmin2, ymin2, xmax2, ymax2) in enumerate(boxes):
            if i == j:
                continue

            #Right neighbour: vertical extents overlap and the box starts after ours ends
            if _spans_overlap(ymin1, ymax1, ymin2, ymax2):
                gap = xmin2 - xmax1
                if 0 <= gap <= min_right:
                    min_right_idx, min_right = j, gap

            #Down neighbour: horizontal extents overlap and the box starts below ours
            if _spans_overlap(xmin1, xmax1, xmin2, xmax2):
                gap = ymin2 - ymax1
                if 0 <= gap <= min_down:
                    min_down_idx, min_down = j, gap

        if min_right_idx is not None:
            matrix[i, min_right_idx] = matrix[min_right_idx, i] = True
        if min_down_idx is not None:
            matrix[i, min_down_idx] = matrix[min_down_idx, i] = True

    # np.nonzero walks the matrix in row-major order, i.e. sorted by source then target.
    source, target = np.nonzero(matrix)
    edge_index = torch.tensor(np.stack((source, target)), dtype=torch.long)

    return edge_index


def tfrecord_transforms(elem,
                   device,
                   max_height = 768,
                   max_width = 1366,
                   num_of_max_vertices = 250,
                   max_length_of_word = 30,
                   batch_size = 8):
    """
    Transform one batch loaded by the TFRecord dataloader into per-image tensors.

    The size parameters are defined by the TIES data generation, which fixes the size
    and complexity of the generated tables. DO NOT CHANGE.

    Parameters:
        elem: batch dict; the keys 'image', 'global_features', 'vertex_features',
            'adjacency_matrix_cells', 'adjacency_matrix_cols' and
            'adjacency_matrix_rows' are read.
        device: torch device every returned tensor is moved to.
        max_height, max_width: image size, used to un-flatten the images and to
            normalise the word box coordinates.
        num_of_max_vertices: padded number of words per image.
        max_length_of_word: not used by this function.
        batch_size: number of samples in `elem`.

    Returns:
        dict with
            'imgs': (batch_size, 1, max_height, max_width) tensor, divided by 255
                when the batch maximum exceeds 1,
            'num_words': column 2 of 'global_features',
            'vertex_features': list of per-image (num_words, 5) float tensors whose
                first four columns are divided by the image width / height,
            'edge_index': list of per-image edge lists from visibility_matrix,
            'adjacency_matrix_cells' / '_cols' / '_rows': lists of per-image
                (num_words, num_words) matrices.

    Side effect: prints the CPU time spent in each stage.
    """
    adjacency_keys = ('adjacency_matrix_cells',
                      'adjacency_matrix_cols',
                      'adjacency_matrix_rows')

    with torch.no_grad():
        #Everything is flattened in tfrecord, so needs to be reshaped.
        data_dict =  {}

        #Torch dimensions: B x C x H x W
        #inputting grayscale, so only 1 dimension
        #Images are in range [0,255], need to be in [0,1]: normalize if the max is over 1
        t = process_time()
        imgs = elem['image']
        if torch.max(imgs) > 1:
            imgs = imgs/255
        data_dict['imgs'] = imgs.reshape(batch_size,1,max_height,max_width).to(device)
        reshape = process_time()-t

        #Extract number of words for each image:
        t = process_time()
        num_words = elem['global_features'][:,2]
        data_dict['num_words'] = num_words.to(device)
        #plain ints are used for all slicing below
        word_counts = [int(n) for n in num_words]
        xnumwords = process_time()-t

        t = process_time()
        v = elem['vertex_features'].reshape(batch_size,num_of_max_vertices,5).float()
        feat_reshap = process_time()-t

        #normalizing word coordinates to be invariant to image size
        v[:,:,0] = v[:,:,0]/max_width
        v[:,:,1] = v[:,:,1]/max_height
        v[:,:,2] = v[:,:,2]/max_width
        v[:,:,3] = v[:,:,3]/max_height

        #drop the padding rows: keep only the real words of each image
        data_dict['vertex_features'] = [vf[0:word_counts[idx]].to(device)
                                        for idx,vf in enumerate(v)]

        #Calculate visibility matrix for each batch element
        t = process_time()
        data_dict['edge_index'] = [visibility_matrix(vex,word_counts[idx]).to(device)
                                   for idx,vex in enumerate(v)]
        visimat = process_time()-t

        #Crop every padded adjacency matrix to the real words in BOTH dimensions.
        #A single [:nw, :nw] is required: chaining [:nw][:nw] only slices the rows twice.
        t = process_time()
        for key in adjacency_keys:
            data_dict[key] = [
                elem[key][idx].reshape(num_of_max_vertices,num_of_max_vertices)[:nw, :nw].to(device)
                for idx,nw in enumerate(word_counts)
            ]
        adjmats = process_time()-t

        print(f'#####TRANSFORMS: reshape: {reshape}, extract number of words: {xnumwords}, feat_reshape: {feat_reshap}, visibility matrix: {visimat}, adjacency matrix: {adjmats}')

        return data_dict

In [108]:
#variables for tfrecord loader
batchsize = 8
index_path = None
#feature name -> type string, handed to TFRecordDataset when the dataset is created
tfrecord_description = {"image": "float",
               "global_features": "int",
               "vertex_features": "int",
               "adjacency_matrix_cells":"int",
               "adjacency_matrix_cols":"int",
               "adjacency_matrix_rows":"int",
               "vertex_text":'int'}

#Folder holding the tfRecords, resolved relative to the current working directory.
#os.path.join keeps the path valid on every platform.
folder_path = os.path.join(os.getcwd(), 'tfrecords')

#load filenames of folder; sorted because os.listdir returns them in arbitrary order
tfrecord_files = sorted(os.listdir(folder_path))

In [109]:
# Display the resolved tfrecords folder as a sanity check.
folder_path

'C:\\Users\\Jesper\\Desktop\\TableRecognition\\Table_Detection_and_Recognition\\Table_Recognition\\tfrecords'

In [110]:
# Build a batched loader over the first tfrecord file listed in the folder.
record = tfrecord_files[0]
tfrecord_path = os.path.join(folder_path, record)
dataset = TFRecordDataset(
    tfrecord_path,
    index_path,
    tfrecord_description,
)
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batchsize)

In [111]:
# Pull a single batch from the loader to experiment with.
batch = next(iter(loader))

In [112]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#load Feature CNN model
#os.path.join keeps the path valid on every platform
featurenet_path = os.path.join(os.getcwd(), "models", "FeatureNet_v1.pt")
featurenet = FeatureNet_v1()
#the weights are loaded onto the CPU first; the model is moved to `device` below
featurenet.load_state_dict(torch.load(featurenet_path,map_location=torch.device('cpu')))
featurenet.eval()
featurenet.to(device)

FeatureNet_v1(
  (conv_1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1))
  (conv_2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1))
  (pool_1): MaxPool2d(kernel_size=5, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv_3): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
  (pool_2): AvgPool2d(kernel_size=5, stride=3, padding=0)
  (dropout): Dropout2d(p=0.5, inplace=False)
  (l_1): Linear(in_features=888832, out_features=128, bias=True)
  (l_out): Linear(in_features=128, out_features=4, bias=False)
)

In [27]:
# Inspect the shape of the raw image tensor in the batch.
batch['image'].shape

torch.Size([8, 1049088])

In [113]:
# Use the configured batch size instead of repeating the literal, so the
# transform always agrees with the DataLoader.
data_dict = tfrecord_transforms(batch,device=device,batch_size=batchsize)

#####TRANSFORMS: reshape: 0.78125, extract number of words: 0.0, feat_reshape: 0.0, visibility matrix: 5.21875, adjacency matrix: 0


In [114]:
# Run the transformed images through the network's feature_forward to get the feature map.
# NOTE(review): this call is made outside torch.no_grad() — confirm gradients are wanted here.
feature_map = featurenet.feature_forward(data_dict['imgs'])

In [136]:
# Shape check on the first sample's feature map.
# NOTE(review): the recorded output below is a full torch.Size, which is what
# `feature_map[0].shape` displays; with the trailing [0] this expression evaluates
# to the first dimension only — re-run the cell to refresh the output.
feature_map[0].shape[0]

torch.Size([32, 124, 224])

In [283]:
def _clamped_window(centre, size, limit):
    """Return (start, stop) of a `size`-wide window around `centre`, shifted to stay within [0, limit]."""
    start = int(centre - np.floor(size/2))
    stop = int(centre + np.ceil(size/2))
    if start < 0:
        # Window sticks out on the low side: slide it so it starts at 0.
        start, stop = 0, stop - start
    if stop > limit:
        # Window sticks out on the high side: slide it so it ends at `limit`.
        start, stop = start - (stop - limit), limit
    return start, stop


# feature_map is laid out as (batch, channels, height, width): x (horizontal)
# positions index the last axis and y (vertical) positions the one before it.
max_y, max_x = feature_map.shape[2], feature_map.shape[3]
# Width / height of one feature-map cell in the normalised [0, 1] word coordinates.
pxl_l, pxl_h = 1/max_x, 1/max_y

# Largest word width and height in the batch, in normalised coordinates.
# The loop index is deliberately not called `batch`, so the data batch loaded
# earlier in the notebook is not overwritten.
max_l, max_h = 0, 0
for sample_idx in range(batchsize):
    boxes = data_dict['vertex_features'][sample_idx].cpu().numpy()
    if len(boxes) == 0:
        continue  # an image without words contributes nothing
    max_l = max(max_l, np.max(boxes[:,2]-boxes[:,0]))
    max_h = max(max_h, np.max(boxes[:,3]-boxes[:,1]))

# Convert to whole feature-map cells. Every word gets a window of this fixed size,
# so all gathered feature vectors have the same length.
max_l, max_h = np.ceil(max_l/pxl_l), np.ceil(max_h/pxl_h)

gathered_feats = []
for sample_idx in range(batchsize):
    words = data_dict['vertex_features'][sample_idx]
    all_feats_img = []
    for word in words:
        x1, y1, x2, y2, _ = word.cpu().numpy()
        # Centre of the word box, expressed in feature-map cells.
        c = np.floor((((x1+x2)/2)/pxl_l, ((y1+y2)/2)/pxl_h))

        x_slice = _clamped_window(c[0], max_l, max_x)
        y_slice = _clamped_window(c[1], max_h, max_y)

        # Rows are y and columns are x. Slicing all channels at once flattens in the
        # same channel-major order as concatenating the channels one by one.
        window = feature_map[sample_idx][:, y_slice[0]:y_slice[1], x_slice[0]:x_slice[1]]
        all_feats_img.append(torch.cat((word, torch.flatten(window))))

    if all_feats_img:
        gathered_feats.append(torch.stack(all_feats_img, dim=0))
    else:
        # Keep one entry per sample even when an image has no words.
        feat_len = words.shape[1] + feature_map.shape[1]*int(max_h)*int(max_l)
        gathered_feats.append(words.new_zeros((0, feat_len)))
