In [1]:
import sys
import os
# Make the parent of the notebook's working directory importable so `src.*` resolves.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)
    

from pathlib import Path

from transformers import AutoTokenizer

from src.domlm import DOMLMConfig

# Pretrained "roberta-base" tokenizer and the DOM-LM configuration.
# The config path is resolved relative to the notebook's working directory.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
config = DOMLMConfig.from_json_file('../domlm-config/config.json')

# Root of the SWDE dataset. Set the SWDE_DATASET_DIR environment variable to point
# elsewhere instead of editing this machine-specific absolute path.
dataset_dir = Path(os.environ.get(
    'SWDE_DATASET_DIR',
    'D:/projects/Forage/GenWeb/webpage_information_extraction/dataset/SWDE_Dataset/',
))
domain = 'restaurant'
groundtruth_dir = dataset_dir / 'groundtruth' / domain  # ground-truth attribute files for this vertical
webpages_dir = dataset_dir / 'webpages'
domain_dir = webpages_dir / domain
website_dir = domain_dir / 'restaurant-frommers(2000)'  # page directory of the frommers site

In [2]:
from src.domlm.modeling_domlm import DOMLMForTokenClassification

In [4]:
# Build the token-classification model directly from `config`; no pretrained
# checkpoint path is passed here.
model = DOMLMForTokenClassification(config)

In [5]:
# NOTE(review): forward() is called without any inputs. It presumably expects the encoded
# subtree fields (input_ids, node_ids, ... — cf. the keys of `padding_idxs` further down),
# so this cell likely raises on a fresh kernel. TODO pass a real batch or drop this cell.
model.forward()

In [2]:
def extract_labels(label_files, header_lines=2):
    """Parse SWDE-style ground-truth files into a per-page label mapping.

    The label name is taken from the file name: the part after the last '-' with
    '.txt' removed (e.g. ``restaurant-frommers-name.txt`` -> ``name``). After skipping
    the first ``header_lines`` lines, each non-blank line is read as tab-separated
    ``page_id<TAB>nums<TAB>value[<TAB>...]``; only the first value column is kept.

    Args:
        label_files: iterable of ``pathlib.Path`` objects (``.name`` is used).
        header_lines: number of leading lines to skip in every file (default 2).

    Returns:
        dict mapping page_id -> {label: {'nums': str, 'value': str}}. If the same
        (page_id, label) pair occurs more than once, the last occurrence wins.

    Raises:
        IndexError: if a non-blank data line has fewer than three tab-separated fields.
    """
    label_info = {}
    for file in label_files:
        label = file.name.split('-')[-1].replace('.txt', '')
        with open(file, 'r') as f:
            lines = f.readlines()
        for line in lines[header_lines:]:
            # Blank lines (e.g. a trailing newline at EOF) carry no label; previously
            # they crashed with IndexError on the field lookup below.
            if not line.strip():
                continue
            fields = line.split('\t')
            page_id, nums, value = fields[0], fields[1], fields[2].strip()
            label_info.setdefault(page_id, {})[label] = {
                'nums': nums,
                'value': value,
            }
    return label_info

In [12]:
# Ground-truth files of the frommers site, one per annotated attribute.
label_files = [
    groundtruth_dir / fname
    for fname in (
        'restaurant-frommers-address.txt',
        'restaurant-frommers-cuisine.txt',
        'restaurant-frommers-name.txt',
        'restaurant-frommers-phone.txt',
    )
]

# page_id -> {attribute: {'nums': ..., 'value': ...}}
label_infos = extract_labels(label_files)


In [15]:
# Load one sample page. glob() yields files in filesystem order, so sort first to make
# the choice reproducible across machines and runs.
page_paths = sorted(website_dir.glob('*.htm'))
assert page_paths, f'no *.htm pages found in {website_dir}'
path = page_paths[0]
# read_text() closes the file itself (the bare open(...).read() left the handle open)
# and uses the same default encoding as open(path, 'r').
html_string = path.read_text()

# Ground-truth labels of this page.
# NOTE(review): assumes the file name up to the first '.' equals the page_id used in the
# ground-truth files — confirm for all sites.
label2text = label_infos[path.name.split('.')[0]]
# Invert to value -> {'label', 'nums'}; if two attributes share the same value text,
# the one iterated last overwrites the other.
text2label = {v['value']: {'label': k, 'nums': v['nums']} for k, v in label2text.items()}

In [50]:
from src.html_utils import get_cleaned_body
from src.preprocess import (extract_token_for_nodes, assign_label_for_nodes, generate_subtrees, postorder, _tokens_len)

In [53]:
# Token budgets, also used by the manual subtree loop in the cells below:
#   m - upper bound on `_tokens_len` of one subtree window
#   s - `_tokens_len` target when expanding a window with the next pre-order nodes
m = tokenizer.model_max_length
s = 128

# Padding value per model input field, taken from the DOM-LM config / tokenizer.
# NOTE(review): "labels" uses -100, presumably the loss ignore index — confirm in the model.
padding_idxs = {
    "node_ids": config.node_pad_id,
    "parent_node_ids": config.node_pad_id,
    "sibling_node_ids": config.sibling_pad_id,
    "depth_ids": config.depth_pad_id,
    "tag_ids": config.tag_pad_id,
    # "position_ids": config.pad_token_id,
    "input_ids": tokenizer.pad_token_id,
    "attention_mask": 0,
    "labels": -100,
}

dom = get_cleaned_body(html_string)
token_repr = extract_token_for_nodes(dom)
node2label = assign_label_for_nodes(dom, text2label)

# `max_nodes` and `stride` were never defined in this notebook, so this call raised
# NameError on a fresh kernel. Bind them to the budgets defined above.
# NOTE(review): assumes generate_subtrees' limits are the same token budgets as `m`/`s`
# in the manual loop below — confirm against src.preprocess.
max_nodes = m
stride = s
subtrees = generate_subtrees(dom, token_repr, max_nodes, stride)  # requires tokenizer

In [55]:
# Materialize the cleaned DOM once in both orders; the subtree loop below indexes into these.
pre_order = [*dom.iter()]
post_order = [*postorder(dom)]

In [56]:
subtrees = []

### init first subtree
# Seed the first window: take pre-order nodes, recording each node's pre-order index,
# until `_tokens_len` of the window reaches `m`. The check runs before each append,
# so the window may end one node past the budget (pruned later).
new, node_ids = [], {}
for position, node in enumerate(pre_order):
    if _tokens_len(new, token_repr) >= m:
        break
    node_ids[node] = position
    new.append(node)

107


In [64]:
# Which tokenizer implementation AutoTokenizer resolved to.
type(tokenizer).__name__

'RobertaTokenizerFast'

In [None]:
# Inspect `_tokens_len` of the seed window `new` built in the first-subtree cell.
# NOTE(review): the saved output below is an Element repr, which does not look like a
# length — it appears stale from an earlier version of this cell; re-run to refresh.
_tokens_len(new, token_repr)

<Element tr at 0x1f1b7b01bd0>

In [None]:



# Sliding-window subtree generation over the pre-order node list. Each iteration:
#   1. builds the window `visited + new`, where `visited` is every already-indexed node
#      preceding new[0] in pre-order (node_ids is filled with increasing pre-order
#      indices, so its insertion order is pre-order); recomputed from scratch each time;
#   2. drops `visited` nodes in post-order until the window fits in `m` or a node of
#      `new` is reached;
#   3. while still over `m`: drops visited[0] if it has fewer than 2 children inside the
#      window, otherwise drops the last node of `new`;
#   4. records the window in `subtrees`, then refills `new` with the following pre-order
#      nodes until `_tokens_len(new)` reaches `s`.
# The loop ends when no pre-order nodes remain after the last kept node.
# NOTE(review): lengths are maintained by subtraction, which assumes _tokens_len is
# additive over nodes — confirm against src.preprocess.
while len(new) != 0:
    visited = [n for n,idx in node_ids.items() if idx < node_ids[new[0]] ]
    total_len = _tokens_len(visited + new, token_repr)
    ###prune postorder
    # Nodes not (or no longer) in `visited` make remove() raise ValueError and are skipped.
    for el in post_order:
        if el in new or total_len <= m:
            break
        else:
            try:
                visited.remove(el)
                total_len -= _tokens_len([el], token_repr)
            except ValueError:
                pass
    ### prune root
    # NOTE(review): this initial value and the one assigned in the else-branch below are
    # never read — el_last is overwritten with new[-1] after the window is recorded.
    el_last = None
    while total_len > m:
        if len(visited) != 0:
            el_root = visited[0]
            # how many of el_root's children are still inside the window
            num_child = sum(child in new or child in visited for child in el_root)
        else:
            num_child = 2 # pop from new if there are no nodes in visited left

        if num_child < 2:
            total_len -= _tokens_len([el_root], token_repr)
            visited.pop(0)
        else:
            el_last = new.pop()
            total_len -= _tokens_len([el_last], token_repr)
    t = visited + new
    subtrees.append(t)

    # Resume expansion right after the last node kept in this window.
    # NOTE(review): if root pruning popped every node of `new` (e.g. one node longer than
    # `m`), new[-1] raises IndexError — TODO decide how oversized nodes are handled.
    el_last = new[-1]
    ### expand subtree
    # The length check runs before each append, so `new` may overshoot `s` by one node.
    new = []
    next_idx = node_ids[el_last] + 1
    for i,el in enumerate(pre_order[next_idx:],start=next_idx):
        if _tokens_len(new, token_repr) >= s:
            break
        new.append(el)
        node_ids[el] = i
