In [None]:
import torch
from torchvision.ops import RoIAlign
from torchvision.models import resnet101
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import pickle
import argparse
import json

# Feature Extractor Function
It employs a visual model, e.g., ResNet101, to extract one visual feature vector per ROI, given an image and a list of ROIs passed as input arguments.

In [None]:
def extract_roi_features(image_path, rois):
    """Extract one pooled visual feature vector per region of interest.

    The image is resized to 800x800, normalised, and passed through the
    module-level ``resnet`` backbone on the module-level ``device``. Each ROI
    is then pooled to a single 1x1 cell with RoIAlign and flattened.

    Args:
        image_path: Path of the image file to load.
        rois: Sequence of boxes given in pixel coordinates of the *original*
            image. Presumably each box is ``[x1, y1, x2, y2]`` as RoIAlign
            requires -- TODO confirm against the annotation format.

    Returns:
        A list with one entry per ROI, each a list of 2048 floats. An empty
        ``rois`` yields an empty list.
    """
    # No ROI means nothing to pool; RoIAlign needs a (K, 5) tensor and cannot
    # be fed an empty Python list converted to a 1-D tensor.
    if len(rois) == 0:
        return []

    target_height, target_width = 800, 800

    # spatial_scale maps network-input pixels to feature-map cells; 1/32
    # assumes the backbone downsamples by 32 overall -- NOTE(review): confirm
    # this matches the module-level ``resnet``.
    roi_align = RoIAlign(output_size=(1, 1), spatial_scale=1/32, sampling_ratio=-1)
    preprocess = transforms.Compose([
        transforms.Resize((target_height, target_width)),  # fixed input size for simplicity
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Close the file handle as soon as the pixel data has been converted.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    original_width, original_height = image.size

    input_batch = preprocess(image).unsqueeze(0).to(device, dtype=torch.float)
    with torch.no_grad():
        res5_features = resnet(input_batch)

    boxes = torch.tensor(rois).float().to(device, dtype=torch.float)

    # The network sees the resized image, so the boxes must be moved from the
    # original image's coordinate frame into the resized one before pooling;
    # otherwise RoIAlign samples the wrong regions.
    x_scale = target_width / original_width
    y_scale = target_height / original_height
    scale = torch.tensor([x_scale, y_scale, x_scale, y_scale],
                         dtype=torch.float, device=device)
    boxes = boxes * scale

    # RoIAlign expects rows of (batch_index, x1, y1, x2, y2); the batch holds
    # a single image, so every index is 0. Built as float so the
    # concatenation does not rely on dtype promotion.
    batch_indices = torch.zeros((boxes.shape[0], 1), dtype=torch.float, device=device)
    indexed_boxes = torch.cat((batch_indices, boxes), dim=1)

    pooled_features = roi_align(res5_features, indexed_boxes)
    roi_features = pooled_features.reshape(-1, 2048)
    return roi_features.tolist()

# Main

In [None]:
# Select which FUNSD split to process ("train" or "test", default "train").
parser = argparse.ArgumentParser()
parser.add_argument("--split", type=str, default="train", choices=["train", "test"])
# parse_known_args() instead of parse_args(): inside a Jupyter kernel sys.argv
# carries the kernel's own options (e.g. "-f <connection-file>.json"), which
# parse_args() would reject and abort the cell with SystemExit.
args, _unknown_args = parser.parse_known_args()
split = args.split

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Keep every child module of the pretrained ResNet-101 except the last two,
# so the network outputs a feature map rather than the final model output.
backbone = resnet101(pretrained=True)
feature_layers = list(backbone.children())[:-2]
resnet = torch.nn.Sequential(*feature_layers).to(device)
# Inference mode: the model is only used for feature extraction.
resnet.eval()

Read the entity-annotated FUNSD data. The dataset can be downloaded from: https://guillaumejaume.github.io/FUNSD/download/

In [None]:
# Parse the annotations twice so that `data` and `copy` are two independent
# objects: `data` is iterated over, `copy` receives the extracted features.
annotations_path = "data/FUNSD/dataset/"+split+"ing_data/all_annotations.json"
with open(annotations_path, 'r') as annotations_file:
    annotations_text = annotations_file.read()

data = json.loads(annotations_text)
copy = json.loads(annotations_text)

Apply the feature extractor to the entire dataset and add a field "visual_list" to each document of the original JSON structure.

In [None]:
# For every document, pool one visual feature vector per annotated box and
# store the result under "visual_list" in the output structure.
for key in tqdm(data.keys()):
    image_path = "data/FUNSD/dataset/"+split+"ing_data/images/"+key+".png"
    doc = data[key]["form"]
    # Drop entities without a box *before* collecting them: an empty box in
    # the list would make the boxes ragged and break the tensor conversion.
    # NOTE(review): when entities are dropped, "visual_list" has fewer entries
    # than "form" -- confirm downstream code does not index them in parallel.
    object_list = [entity["box"] for entity in doc if len(entity["box"]) != 0]
    if object_list:
        features = extract_roi_features(image_path, object_list)
    else:
        # No usable boxes in this document: nothing to extract.
        features = []
    copy[key]["visual_list"] = features

Save the new data structure containing the computed visual features as a pickle file.

In [None]:
# Serialise the annotations enriched with "visual_list" next to the original
# annotation file of the selected split.
output_path = 'data/FUNSD/dataset/'+split+'ing_data/all_annotations_visual.pickle'
with open(output_path, 'wb') as pickle_file:
    pickle.dump(copy, pickle_file, protocol=pickle.HIGHEST_PROTOCOL)