In [None]:
# !unzip html2json.zip

In [None]:
!pip install evaluate

In [8]:
import torch
from torch import nn
from functools import partial
from html2json import HTML_JSON_Dataset, padding_collate_fn
from torch.utils.data import DataLoader, random_split
from html2json.charactertokenizer import HTMLTokenizer, JSONTokenizer
from html2json.charactertokenizer import MASK_TOKEN
from html2json import load_data
from html2json.seq2seq import Seq2SeqTransformer
from html2json.seq2seq import translate_greedy_search, translate_beam_search
from html2json.training import train_epoch, evaluate
from timeit import default_timer as timer
from evaluate import load
import os

In [9]:
torch.cuda.empty_cache()
torch.cuda.is_available()

True

In [10]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
html_pth = 'generated_tables/tables'
json_pth = 'generated_tables/metadata'

In [12]:
html_data, json_data = load_data(html_pth, json_pth, as_string=False, limit=None)

In [13]:
html_tokenizer = HTMLTokenizer(html_data)
json_tokenizer = JSONTokenizer(json_data)



In [14]:
collate_fn = partial(padding_collate_fn, pad_token_html = MASK_TOKEN, pad_token_json = MASK_TOKEN)
html_data_str, json_data_str = load_data(html_pth, json_pth, as_string=True, limit=None)
h2j_dataset = HTML_JSON_Dataset([html_tokenizer.encode(h) for h in html_data_str], [json_tokenizer.encode(j) for j in json_data_str])

In [15]:
train_set, val_set = random_split(h2j_dataset, [0.8, 0.2], torch.Generator().manual_seed(42))

In [9]:
torch.manual_seed(42)
SRC_VOCAB_SIZE = len(html_tokenizer)
TGT_VOCAB_SIZE = len(json_tokenizer)
EMB_SIZE = 256
NHEAD = 8
FFN_HID_DIM = 4096
BATCH_SIZE = 32
NUM_ENCODER_LAYERS = 1
NUM_DECODER_LAYERS = 1
LR = 0.001
NUM_EPOCHS = 25

torch.cuda.empty_cache()
train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, collate_fn=collate_fn)
validation_dataloader = DataLoader(val_set, batch_size=BATCH_SIZE, collate_fn=collate_fn)

In [10]:
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

if os.path.exists("./assets/transformer.pt"):
    transformer.load_state_dict(torch.load("./assets/transformer.pt", map_location=torch.device(DEVICE)))
else:
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
transformer = transformer.to(DEVICE)

  transformer.load_state_dict(torch.load("./assets/transformer.pt", map_location=torch.device(DEVICE)))


In [None]:
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=MASK_TOKEN)
optimizer = torch.optim.Adam(transformer.parameters(), lr=LR, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, threshold=0.1, threshold_mode='rel')

In [14]:
for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    transformer.train()
    train_loss = train_epoch(transformer, optimizer, train_dataloader, loss_fn)
    end_time = timer()
    scheduler.step(train_loss)
    # evaluation
    transformer.eval()
    val_loss = evaluate(transformer, validation_dataloader, loss_fn)
    # add save model checkpoint every 20 epochs
    if epoch % 5 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': transformer.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': train_loss,
                }, f"./checkpoints/checkpoint_{epoch}.pt")
    # val_loss = evaluate(transformer)
    
    print(f"Epoch: {epoch}, Train loss: {train_loss:.5f}, Val loss: {val_loss:.5f}, "f"Epoch time = {(end_time - start_time):.3f}s, lr: {scheduler.get_last_lr()}")
# save the model after training
torch.save(transformer.state_dict(), "./assets/transformer.pt")



Epoch: 1, Train loss: 4.788, Val loss: 4.303, Epoch time = 7.219s, lr: [0.001]


In [6]:
train_idx, val_idx = random_split(range(len(h2j_dataset)), [0.8, 0.2], torch.Generator().manual_seed(42))

NameError: name 'h2j_dataset' is not defined

In [5]:
sample_num = 0
idx = val_idx[sample_num]
val_idx = html_data_str[idx]

NameError: name 'val_idx' is not defined

In [13]:
val_idx

'<table><caption>Table 28.13.13.94 Broadcast presenter</caption><thead><tr><th></th><th>James Jones</th><th>Gerald Kelley</th><th>James Potter</th><th>Alexander Hill</th><th>Ryan Smith</th><th>Brandon Martin</th><th>Erin Dickson</th></tr></thead><tbody><tr><td>Harrison, Richardson and Wilson</td><td>579</td><td>1334</td><td>654</td><td>1194</td><td>1184</td><td>1362</td><td>682</td></tr><tr><td>Reed LLC</td><td>504</td><td>849%</td><td>701%</td><td>965</td><td>421</td><td>286%</td><td>177%</td></tr><tr><td>Klein LLC</td><td>1153</td><td>85</td><td>1041</td><td>554</td><td>900</td><td>1435%</td><td>210%</td></tr><tr><td>Gonzalez Inc</td><td>601%</td><td>1461</td><td>1145</td><td>1586</td><td>1192</td><td>1205</td><td>1101</td></tr><tr><td>Watson, Brown and Long</td><td>1154%</td><td>1303%</td><td>1334%</td><td>1046</td><td>648</td><td>468</td><td>157</td></tr></tbody><tfoot>Creation: 12Jul2008 Fiji</tfoot></table>\n'

In [37]:
pred = translate_greedy_search(transformer, val_idx, html_tokenizer, json_tokenizer)

In [38]:
json_data_str[idx]

'[{]["body"][:][{]["content"][:][[]"1560"[,]"1012"[,]"694"[,]"800"[,]"240"[,]"1371"[,]"314%"[,]"342%"[,]"1204%"[,]"189%"[,]"1536"[,]"1349%"[]][,]["headers"][:][{]["col"][:][[]"Lopez-Foster"[,]"Bolton-Thompson"[]][,]["row"][:][[]"Nancy Vasquez"[,]"Jason Martinez"[,]"Brittany Mcbride"[,]"Victor Phillips"[,]"Matthew Luna"[,]"Tina Smith"[]][}][}][,]["footer"][:][{]["table_creation_date:"][:]"20Jan2022"[,]["text"][:]"Creation: 20Jan2022 Madagascar"[}][,]["header"][:][{]["table_id"][:]"29"[,]["text"][:]"Table 29 Volunteer coordinator"[}][}]'

In [39]:
pred[5:-5]

'[}]["table_id"]["row"][}][{]["row"]["table_creation_date:"]FfSLF["text"]FdaSLF["text"]FdfLF["text"]FSdPLF["text"]FdSLF["text"]FdSLF["text"]FdSLF["text"]FdaLF["text"]FfSdF["text"]FaLF["text"]FaakF["text"]FncbqW,WnXF["text"]F1(q/7pW1(QF["text"]F1cc7W1(F["text"]F1cQc)pW1pLF["text"]F1\'LF["text"]F1pQqpW1c\'F["text"]Fhp/DqpF["text"]F1p\'F["text"]F1p\'F["text"]F1pxxpW15pF["text"]F1p\'F["text"]F1pxpiple\nc(vep47e,c7F["text"]F1(QcF["text"]F1(QcF["text"]F1pQqcennXF["text"]F1(QcF["text"]F1(pF["text"]F1(pF["text"]F1p\'F["text"]F1pQqpW1(F["text"]F1pQc/ep47e\nF["text"]F1pQqp/YF["text"]F1pDtpiF["text"]F1p\'F["text"]F1p\'F["text"]F1p\'F["text"]F1p\'F["text"]F1p\'F["text"]F1pxxpe\nenF["text"]F1\'p4pipe\nD54F["text"]F1pxpe64DF["text"]F\npicenx7F["text"]F\ncQe1qDcF["text"]F1pQqpe1cc7F["text"]F1pQqpe1cc7F["text"]F1pQqpe1cc7F["text"]F1pQqpe1(F["text"]F1p\'F["text"]F1pQqpe1(F["text"]F1pQqpe1pD5F["text"]F1p\'F["text"]F1p\'F["text"]F1p\'F["text"]F1p\'F["text"]F1p\'F["text"]F1p\'F["text"]F1p\'F["text"]F1p\'F

In [22]:
pred[5:-5] == json_data_str[idx]

False

In [13]:
translate_beam_search(transformer, val_idx, html_tokenizer, json_tokenizer)

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


RuntimeError: shape '[481, 16, 32]' is invalid for input of size 123136

In [2]:
bleu = load("bleu")

Collecting evaluate
  Obtaining dependency information for evaluate from https://files.pythonhosted.org/packages/a2/e7/cbca9e2d2590eb9b5aa8f7ebabe1beb1498f9462d2ecede5c9fd9735faaf/evaluate-0.4.3-py3-none-any.whl.metadata
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting datasets>=2.0.0 (from evaluate)
  Obtaining dependency information for datasets>=2.0.0 from https://files.pythonhosted.org/packages/be/3e/e58d4db4cfe71e3ed07d169af24db30cfd582e16f977378bd43fd7ec1998/datasets-3.0.1-py3-none-any.whl.metadata
  Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB)
Collecting dill (from evaluate)
  Obtaining dependency information for dill from https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl.metadata
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from evaluate)
  Obtaining dependency information for xxhash from https://files.pythonhosted.org/


[notice] A new release of pip is available: 23.2.1 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [None]:
predictions = [translate_greedy_search(transformer, html_data_str[idx], html_tokenizer, json_tokenizer) for i, idx in enumerate(val_idx) if i <= 10]
references = [json_data_str[idx] for i, idx in enumerate(val_idx) if i <= 10]

In [3]:
bleu.compute(predictions=predictions, references=references)

Downloading builder script:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]