In [1]:
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# !pip install transformers

In [2]:
import torch
torch.cuda.is_available()

True

In [3]:
torch.cuda.empty_cache()

In [4]:
# In this notebook we train a translation model from data in html format to json format.
# the model architecture is a transformer model with an encoder-decoder architecture.
# the model is charcter based and is trained on a dataset of html files and their corresponding json files.
# The steps are as follows:
# 1. remove all attributes from html tags as they are not needed for translation
# 2. create a tokenizer for the html data where each tag is a token along with the characters inside the tags.
# 3. create a tokenizer for the json data where each key is a token along with the characters inside the values.
# 4. create a transformer model with an encoder-decoder architecture.
# 5. train the model on the dataset using teacher forcing and a custom loss function - the loss function is a combination of cross entropy loss and a custom loss function that penalizes the model for creating the right structure.
# 6. evaluate the model on the test set.
# 7. save the model for future use.

In [5]:
# create a torch dataset from the html and json files
from torch.utils.data import Dataset
import os
import json
import torch
from bs4 import BeautifulSoup
from charactertokenizer import CharacterTokenizer

In [6]:
def remove_attrs(soup):
    for tag in soup.find_all(True):
        tag.attrs = {}
    return soup

In [7]:
# removing trailing and leading whitespaces from tag.strings for all html data
def remove_whitespace(soup):
    for tag in soup.find_all(True):
        if tag.string is None:
            continue
        tag.string = tag.string.strip()
    return soup

In [8]:
# load all html a store it in memory to save time in io operations
html_data = []
html_str_data = []
for html_file in os.listdir('./beaconcure_data/tables'):
    with open(f'./beaconcure_data/tables/{html_file}') as f:
        soup = BeautifulSoup(f, 'html.parser')
        soup = remove_attrs(soup)
        soup = remove_whitespace(soup)
        html_data.append(soup)
        #### TODO: remove the newlines between tags in the html files but not from the string data, e.g. from the soup object.
        html_str_data.append(str(soup).replace(">\n<", "><"))

In [9]:
# building a tokenizer for the html data, each tag is a token and the characters inside the tags are also tokens
# get a set of all tags in the html files
html_regular_tokens = set()
html_special_tokens = set()
for html_file in html_data:
    # add all tags to the set
    for tag in html_file.find_all(True):
        html_special_tokens.add("<{tag_name}>".format(tag_name = tag.name))
        html_special_tokens.add("</{tag_name}>".format(tag_name = tag.name))
    # add all characters to the set
    for char in html_file.get_text():
        html_regular_tokens.add(char)

In [10]:
html_regular_tokens

{'\n',
 ' ',
 '%',
 '&',
 "'",
 '(',
 ')',
 ',',
 '-',
 '.',
 '/',
 '0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 ':',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z'}

In [11]:
# create a tokenizer for the html data
html_tokenizer = CharacterTokenizer(html_regular_tokens, html_special_tokens, 10000, padding=True)



In [12]:
html_str_data[0]

'<table><caption>Table 59.99.9.62 Loss adjuster, chartered</caption><thead><tr><th></th><th>Daniel Brown</th><th>Shane Barnes DDS</th><th>Nicole Carpenter</th><th>Kristin Duarte</th></tr><th></th><th>programmer</th><th>Carpenter</th><th>singer</th><th>actor</th>\n\n</thead><tbody><tr><td>Roberts LLC</td><td>1060</td><td>37</td><td>1593</td><td>1364</td></tr></tbody><tfoot>modified: 5Feb2013</tfoot><tfoot>Creation: 3Feb2013 Chad</tfoot></table>\n'

In [13]:
html_tokenizer.decode(html_tokenizer(html_str_data[0])['input_ids'])

'[CLS]<table><caption>Table 59.99.9.62 Loss adjuster, chartered</caption><thead><tr><th></th><th>Daniel Brown</th><th>Shane Barnes DDS</th><th>Nicole Carpenter</th><th>Kristin Duarte</th></tr><th></th><th>programmer</th><th>Carpenter</th><th>singer</th><th>actor</th>\n\n</thead><tbody><tr><td>Roberts LLC</td><td>1060</td><td>37</td><td>1593</td><td>1364</td></tr></tbody><tfoot>modified: 5Feb2013</tfoot><tfoot>Creation: 3Feb2013 Chad</tfoot></table>\n[SEP]'

In [14]:
import json

class CustomJSONEncoder(json.JSONEncoder):
    def encode(self, obj):
        def custom_format(value):
            if isinstance(value, dict):
                items = [f'[{json.dumps(k)}][:]{custom_format(v)}' for k, v in value.items()]
                return f'[{{]{"[,]".join(items)}[}}]'
            elif isinstance(value, list):
                items = [custom_format(v) for v in value]
                return f'[[]{"[,]".join(items)}[]]'
            else:
                return json.dumps(value)
        
        return custom_format(obj)

In [15]:
json_data = []
json_str_data = []
for json_file in os.listdir('./beaconcure_data/metadata'):
    with open(f'./beaconcure_data/metadata/{json_file}') as f:
        parsed_json = json.load(f)
        json_data.append(parsed_json)
        json_str_data.append(json.dumps(parsed_json, cls=CustomJSONEncoder))

In [16]:
def get_keys(dictionary):
    keys = set()
    if isinstance(dictionary, list):
        for item in dictionary:
            keys.update(get_keys(item))
    elif isinstance(dictionary, dict):
        for key in dictionary:
            keys.add(key)
            keys.update(get_keys(dictionary[key]))
    return keys

In [17]:
def get_values(dictionary):
    values = set()
    if isinstance(dictionary, list):
        for item in dictionary:
            values.update(get_values(item))
    elif isinstance(dictionary, dict):
        for key, value in dictionary.items():
            # values.add(value)
            values.update(get_values(dictionary[key]))
    else:
        for value in dictionary:
            values.add(value)
    return values

In [18]:
# building a tokenizer for the html data, each tag is a token and the characters inside the tags are also tokens
# get a set of all tags in the html files
json_regular_tokens = set()
json_special_tokens = set()
json_special_tokens.add("[{]")
json_special_tokens.add("[}]")
json_special_tokens.add("[:]")
json_special_tokens.add("[,]")
json_special_tokens.add("[[]")
json_special_tokens.add("[]]")
json_regular_tokens.add("\"")
json_regular_tokens.add("\\")
for json_file in json_data:
    # add all tags to the set
    for key in get_keys(json_file):
        json_special_tokens.add(f"[\"{key}\"]")
    # add all characters to the set
    json_regular_tokens.update(get_values(json_file))

In [19]:
json_regular_tokens

{'\n',
 ' ',
 '"',
 '%',
 '&',
 "'",
 '(',
 ')',
 ',',
 '-',
 '.',
 '/',
 '0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 ':',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z',
 '\\',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z'}

In [20]:
json_special_tokens

{'["body"]',
 '["col"]',
 '["content"]',
 '["footer"]',
 '["header"]',
 '["headers"]',
 '["row"]',
 '["table_creation_date:"]',
 '["table_id"]',
 '["text"]',
 '[,]',
 '[:]',
 '[[]',
 '[]]',
 '[{]',
 '[}]'}

In [21]:
# create a tokenizer for the json data
json_tokenizer = CharacterTokenizer(json_regular_tokens, json_special_tokens, 10000)

In [22]:
json_str_data[0]

'[{]["body"][:][{]["content"][:][[]"1060"[,]"37"[,]"1593"[,]"1364"[]][,]["headers"][:][{]["col"][:][[]"Roberts LLC"[]][,]["row"][:][[]"Daniel Brown"[,]"Shane Barnes DDS"[,]"Nicole Carpenter"[,]"Kristin Duarte"[,]"programmer"[,]"Carpenter"[,]"singer"[,]"actor"[]][}][}][,]["footer"][:][{]["table_creation_date:"][:]"3Feb2013"[,]["text"][:]"modified: 5Feb2013\\nCreation: 3Feb2013 Chad"[}][,]["header"][:][{]["table_id"][:]"59.99.9.62"[,]["text"][:]"Table 59.99.9.62 Loss adjuster, chartered"[}][}]'

In [23]:
json_tokenizer.decode(json_tokenizer(json_str_data[0])['input_ids'])[5:-5]

'[{]["body"][:][{]["content"][:][[]"1060"[,]"37"[,]"1593"[,]"1364"[]][,]["headers"][:][{]["col"][:][[]"Roberts LLC"[]][,]["row"][:][[]"Daniel Brown"[,]"Shane Barnes DDS"[,]"Nicole Carpenter"[,]"Kristin Duarte"[,]"programmer"[,]"Carpenter"[,]"singer"[,]"actor"[]][}][}][,]["footer"][:][{]["table_creation_date:"][:]"3Feb2013"[,]["text"][:]"modified: 5Feb2013\\nCreation: 3Feb2013 Chad"[}][,]["header"][:][{]["table_id"][:]"59.99.9.62"[,]["text"][:]"Table 59.99.9.62 Loss adjuster, chartered"[}][}]'

In [24]:
class BeaconCureDataset(Dataset):
    def __init__(self, html_data, json_data, html_tokenizer, json_tokenizer):
        self.html_data = [html_tokenizer.encode(html_str) for html_str in html_data]
        self.json_data = [json_tokenizer.encode(json_str) for json_str in json_data]

    def __len__(self):
        return len(self.html_data)

    def __getitem__(self, idx):
        return torch.LongTensor(self.html_data[idx]), torch.LongTensor(self.json_data[idx])

In [25]:
def collate_fn(batch, PAD_TOKEN_HTML, PAD_TOKEN_JSON):
    src_batch, tgt_batch = list(zip(*batch))
    src_batch = pad_sequence(src_batch, padding_value=PAD_TOKEN_HTML)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_TOKEN_JSON)
    return src_batch, tgt_batch

In [26]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from functools import partial
PAD_TOKEN_HTML = html_tokenizer.get_vocab()['[PAD]']
PAD_TOKEN_JSON = json_tokenizer.get_vocab()['[PAD]']

collate_fn_partial = partial(collate_fn, PAD_TOKEN_HTML = PAD_TOKEN_HTML, PAD_TOKEN_JSON = PAD_TOKEN_JSON)
bc_dataset = BeaconCureDataset(html_str_data, json_str_data, html_tokenizer, json_tokenizer)
# train_dataloader = DataLoader(bc_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn_partial)


In [27]:
len(bc_dataset[0][0])

246

In [28]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

In [51]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    # tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_TOKEN_HTML).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_TOKEN_JSON).transpose(0, 1)
    return src_mask, src_padding_mask, tgt_padding_mask

In [52]:

len(json_tokenizer)

99

In [47]:
torch.manual_seed(0)

SRC_VOCAB_SIZE = len(html_tokenizer)
TGT_VOCAB_SIZE = len(json_tokenizer)
EMB_SIZE = 32
NHEAD = 4
FFN_HID_DIM = 64
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
torch.cuda.empty_cache()

train_dataloader = DataLoader(bc_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn_partial)

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_TOKEN_JSON)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [48]:
def train_epoch(model, optimizer):
    model.train()
    losses = 0

    for i, (src, tgt) in enumerate(train_dataloader):
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        tgt_mask = model.transformer.generate_square_subsequent_mask(tgt_input.size(0)).to(DEVICE)
        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        print("Batch: {0}, Loss: {1}".format(i, loss.item()))
        losses += loss.item()

    return losses / len(list(train_dataloader))

In [49]:
# def evaluate(model):
#     model.eval()
#     losses = 0
#     val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
# 
#     for src, tgt in val_dataloader:
#         src = src.to(DEVICE)
#         tgt = tgt.to(DEVICE)
# 
#         tgt_input = tgt[:-1, :]
# 
#         src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
# 
#         logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)
# 
#         tgt_out = tgt[1:, :]
#         loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
#         losses += loss.item()
# 
#     return losses / len(list(val_dataloader))

In [50]:
from timeit import default_timer as timer
NUM_EPOCHS = 18

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer)
    end_time = timer()
    # val_loss = evaluate(transformer)
    val_loss = 0
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))

Batch: 0, Loss: 4.841628074645996
Batch: 1, Loss: 4.818394660949707
Batch: 2, Loss: 4.791187763214111
Batch: 3, Loss: 4.758297443389893
Batch: 4, Loss: 4.752079010009766
Batch: 5, Loss: 4.723038196563721
Batch: 6, Loss: 4.694632530212402
Batch: 7, Loss: 4.673187732696533
Batch: 8, Loss: 4.654795169830322
Batch: 9, Loss: 4.6329264640808105
Batch: 10, Loss: 4.614378452301025
Batch: 11, Loss: 4.598561763763428
Batch: 12, Loss: 4.582469940185547
Batch: 13, Loss: 4.582179069519043
Batch: 14, Loss: 4.545294284820557
Batch: 15, Loss: 4.5380539894104
Batch: 16, Loss: 4.521754741668701
Batch: 17, Loss: 4.5250749588012695
Batch: 18, Loss: 4.501384735107422
Batch: 19, Loss: 4.498808860778809
Batch: 20, Loss: 4.479899883270264
Batch: 21, Loss: 4.461790561676025
Batch: 22, Loss: 4.460177898406982
Batch: 23, Loss: 4.45826530456543
Batch: 24, Loss: 4.438604831695557
Batch: 25, Loss: 4.42635440826416
Batch: 26, Loss: 4.41569185256958
Batch: 27, Loss: 4.420628547668457
Batch: 28, Loss: 4.41126871109008

KeyboardInterrupt: 