In [None]:
# !unzip html2json.zip

In [1]:
import torch
from torch import nn
from functools import partial
from html2json import HTML_JSON_Dataset, padding_collate_fn
from torch.utils.data import DataLoader, random_split
from html2json.charactertokenizer import HTMLTokenizer, JSONTokenizer
from html2json.charactertokenizer import MASK_TOKEN
from html2json import load_data
from html2json.seq2seq import Seq2SeqTransformer
from html2json.seq2seq import translate
from html2json.training import train_epoch, evaluate
from timeit import default_timer as timer
import os

In [2]:
torch.cuda.empty_cache()
torch.cuda.is_available()

True

In [3]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
html_pth = 'generated_tables/tables'
json_pth = 'generated_tables/metadata'

In [5]:
html_data, json_data = load_data(html_pth, json_pth, as_string=False, limit=None)

In [6]:
html_tokenizer = HTMLTokenizer(html_data)
json_tokenizer = JSONTokenizer(json_data)



In [7]:
collate_fn = partial(padding_collate_fn, pad_token_html = MASK_TOKEN, pad_token_json = MASK_TOKEN)
html_data_str, json_data_str = load_data(html_pth, json_pth, as_string=True, limit=None)
h2j_dataset = HTML_JSON_Dataset([html_tokenizer.encode(h) for h in html_data_str], [json_tokenizer.encode(j) for j in json_data_str])

In [8]:
train_set, val_set = random_split(h2j_dataset, [0.8, 0.2], torch.Generator().manual_seed(42))

In [None]:
torch.manual_seed(42)
SRC_VOCAB_SIZE = len(html_tokenizer)
TGT_VOCAB_SIZE = len(json_tokenizer)
EMB_SIZE = 256
NHEAD = 8
FFN_HID_DIM = 4096
BATCH_SIZE = 32
NUM_ENCODER_LAYERS = 1
NUM_DECODER_LAYERS = 1
LR = 0.001
NUM_EPOCHS = 25

torch.cuda.empty_cache()
train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, collate_fn=collate_fn)
validation_dataloader = DataLoader(val_set, batch_size=BATCH_SIZE, collate_fn=collate_fn)

In [None]:
transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

if os.path.exists("./assets/transformer.pt"):
    transformer.load_state_dict(torch.load("./assets/transformer.pt", map_location=torch.device(DEVICE)))
else:
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
transformer = transformer.to(DEVICE)

In [None]:
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=MASK_TOKEN)
optimizer = torch.optim.Adam(transformer.parameters(), lr=LR, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, threshold=0.1, threshold_mode='rel')

In [14]:
for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    transformer.train()
    train_loss = train_epoch(transformer, optimizer, train_dataloader, loss_fn)
    end_time = timer()
    scheduler.step(train_loss)
    # evaluation
    transformer.eval()
    val_loss = evaluate(transformer, validation_dataloader, loss_fn)
    # add save model checkpoint every 20 epochs
    if epoch % 5 == 0:
        torch.save({
                'epoch': epoch,
                'model_state_dict': transformer.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': train_loss,
                }, f"./checkpoints/checkpoint_{epoch}.pt")
    # val_loss = evaluate(transformer)
    
    print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s, lr: {scheduler.get_last_lr()}")
# save the model after training
torch.save(transformer.state_dict(), "./assets/transformer.pt")



Epoch: 1, Train loss: 4.788, Val loss: 4.303, Epoch time = 7.219s, lr: [0.001]


In [None]:
_, val_str_set = random_split(json_data_str, [0.8, 0.2], torch.Generator().manual_seed(42))

In [None]:
val_idx = val_str_set[3]

In [19]:
pred = translate(transformer, html_data_str[val_idx], html_tokenizer, json_tokenizer)

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


In [20]:
json_data_str[val_idx]

'[{]["body"][:][{]["content"][:][[]"1060"[,]"37"[,]"1593"[,]"1364"[]][,]["headers"][:][{]["col"][:][[]"Roberts LLC"[]][,]["row"][:][[]"Daniel Brown"[,]"Shane Barnes DDS"[,]"Nicole Carpenter"[,]"Kristin Duarte"[,]"programmer"[,]"Carpenter"[,]"singer"[,]"actor"[]][}][}][,]["footer"][:][{]["table_creation_date:"][:]"3Feb2013"[,]["text"][:]"modified: 5Feb2013\\nCreation: 3Feb2013 Chad"[}][,]["header"][:][{]["table_id"][:]"59.99.9.62"[,]["text"][:]"Table 59.99.9.62 Loss adjuster, chartered"[}][}]'

In [21]:
pred[5:-5]

'[CLS]J3FFFFFDQ33FFFFFKB40["row"]qFS3VFS3VFFFFS3FS3FS3VFS3FKB40FFFFFFF[,]F[,]FFFFK["table_id"]/2u: :v3FK["table_id"]J3FK["row"]qFS3FS3FS3FKB4: : : :[,]FS\n:[,]R[,]FS\n:P:S3FS3\n:P:S33VFS\n:VFS3FS\nz:S\npJ36["row"]qw3\ndMz:V["table_creation_date:"]pg\n45bFS\ndM["table_creation_date:"]\ndM["table_creation_date:"]VFS\ndM["table_creation_date:"]["table_creation_date:"]["table_creation_date:"]V["table_creation_date:"]V["table_creation_date:"]["table_creation_date:"]["table_creation_date:"]["table_creation_date:"]pgf[,]vzW["table_creation_date:"]D["table_creation_date:"]D["table_creation_date:"]D["table_creation_date:"]D["table_creation_date:"]VFK["table_id"]/fsFS\ndMz:VFD2[SEP]'

In [22]:
pred[5:-5] == json_data_str[val_idx]

False