# Загрузка библиотек

In [None]:
import os
import json

import numpy as np

from glob import glob
from transformers import MarkupLMFeatureExtractor, MarkupLMProcessor, MarkupLMForTokenClassification
from bs4 import BeautifulSoup
from torch.utils.data import Dataset, random_split, DataLoader


# Configuration

In [None]:
# Configured mini-batch size.
# NOTE(review): the DataLoader cells in this notebook pass a literal batch size
# instead of this value - confirm which one is intended.
batch_size = 50

# Class id -> class name used by the token-classification head.
id2label = {1: "BEGIN", 0: "OTHER"}
# Inverse mapping (class name -> class id), derived so the two cannot drift apart.
label2id = {name: class_id for class_id, name in id2label.items()}

# Загрузка данных

In [None]:
def load_from_folder(folder_path : str) -> list:
    """Load every ``*.json`` file of a folder and label its HTML nodes.

    Each file must contain a JSON object with:

    * ``"html"``   - the HTML source of a page;
    * ``"xpaths"`` - the collection of xpaths that are labelled as positive.

    The HTML is run through ``MarkupLMFeatureExtractor``.  Every extracted
    node whose xpath occurs in ``"xpaths"`` gets label ``1``; every other node
    gets label ``0``.

    Args:
        folder_path: Directory that is scanned (non-recursively) for
            ``*.json`` files.

    Returns:
        A list with one dict per file, holding ``'nodes'`` and ``'xpaths'``
        exactly as returned by the extractor, ``'node_labels'`` (a one-element
        list wrapping the per-node label list, mirroring the extractor's
        nesting) and the raw ``'html'`` string.

    Raises:
        KeyError: If a file lacks the ``"html"`` or ``"xpaths"`` entry.
    """
    extractor = MarkupLMFeatureExtractor()

    folder_path = os.path.abspath(folder_path)
    # Sorted so the record order does not depend on filesystem enumeration order.
    files_path = sorted(glob(os.path.join(folder_path, "*.json")))

    data = []

    for file_path in files_path:
        print(file_path)
        # Explicit UTF-8: the platform default encoding may fail on non-ASCII HTML.
        with open(file_path, encoding="utf-8") as file:
            info = json.load(file)

        html = info["html"]
        # A set gives O(1) membership tests; assumes the xpaths are hashable (strings).
        labeled_xpaths = set(info["xpaths"])

        encoding = extractor(html)

        # The extractor returns batched output; index 0 is this single document.
        labels = [1 if xpath in labeled_xpaths else 0 for xpath in encoding["xpaths"][0]]

        # Diagnostics: total node count and the positive labels that were found.
        print(len(labels))
        print([label for label in labels if label != 0])

        # Wrap in a list so the labels carry the same nesting as nodes/xpaths.
        labels = [labels]
        print(len(encoding['nodes'][0]), len(encoding['xpaths'][0]), len(labels[0]))
        data.append({'nodes': encoding['nodes'],
                     'xpaths': encoding['xpaths'],
                     'node_labels': labels,
                     'html': html})

    return data
    

In [None]:
# Load the training records first, then the validation records; map() is lazy
# and the tuple unpacking consumes it in order.
train_data, valid_data = map(
    load_from_folder,
    ("test_dataset/train_part", "test_dataset/test_part"),
)

In [None]:
# Sanity check: print every node of one validation record whose label is not OTHER.
idx = 0
for node, label in zip(valid_data[idx]['nodes'][0], valid_data[idx]['node_labels'][0]):
    if id2label[label] == 'OTHER':
        continue
    print(node, id2label[label])
     

# Dataset initialization

In [None]:
class MarkupLMDataset(Dataset):
    """Dataset for token classification with MarkupLM.

    Wraps a list of records (dicts with ``'nodes'``, ``'xpaths'`` and
    ``'node_labels'`` entries) and encodes one record per item with the
    given processor.
    """

    def __init__(self, data, processor=None):
        """Store the records and the processor used to encode them.

        Args:
            data: Sequence of record dicts.
            processor: Callable accepting ``nodes``, ``xpaths`` and
                ``node_labels`` keyword arguments.  It must be provided
                before items are accessed.
        """
        self.data = data
        self.processor = processor

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Encode record ``idx`` and return its tensors without the batch axis.

        The processor is called with ``padding="max_length"``,
        ``truncation=True`` and ``return_tensors="pt"``.
        """
        # first, get nodes, xpaths and node labels
        item = self.data[idx]
        nodes, xpaths, node_labels = item['nodes'], item['xpaths'], item['node_labels']

        # provide to processor
        encoding = self.processor(nodes=nodes, xpaths=xpaths, node_labels=node_labels, padding="max_length", truncation=True, return_tensors="pt")

        # Remove only the leading batch dimension.  A bare squeeze() would also
        # collapse any other size-1 dimension and change the tensor's rank.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}

        return encoding

In [None]:
processor = MarkupLMProcessor.from_pretrained(
    "microsoft/markuplm-base",
    truncation=True,
)
# The records already carry extracted nodes and xpaths, so the processor is
# told not to parse raw HTML itself.
processor.parse_html = False

# One dataset per split, both sharing the same processor.
train_set, valid_set = (
    MarkupLMDataset(data=split, processor=processor)
    for split in (train_data, valid_data)
)

In [None]:
# Encode the first validation record and list the shape of every tensor in it.
example = valid_set[0]
for key, tensor in example.items():
    print(key, tensor.shape)

In [None]:
# Decode the example's token ids back to text to inspect what the processor produced.
processor.decode(example['input_ids'])

In [None]:
# Print the decoded token for every position of the example labelled 1.
# The loop variable is `token_id` rather than `id` so the builtin `id()` is not
# shadowed for the rest of the notebook session.
for token_id, label in zip(example['input_ids'].tolist(), example['labels'].tolist()):
    if label == 1:
        print(processor.decode([token_id]), label)

In [None]:
# Batch both datasets three records at a time, reshuffled on every pass.
train_dataloader = DataLoader(
    train_set,
    batch_size=3,
    shuffle=True,
)
valid_dataloader = DataLoader(
    valid_set,
    batch_size=3,
    shuffle=True,
)

In [None]:
# Pretrained MarkupLM with a token-classification head configured with this
# notebook's label mappings.
model = MarkupLMForTokenClassification.from_pretrained(
    "microsoft/markuplm-base",
    id2label=id2label,
    label2id=label2id,
)

In [None]:
# Tag names indexed by class id: label_list[i] must be the tag of class i,
# because get_labels looks tags up by predicted / gold class id.  Iterating
# id2label.values() directly would follow the dict's insertion order (1, 0)
# and swap the two tags.  Assumes class ids are contiguous starting at 0.
label_list = ["B-" + id2label[class_id] for class_id in sorted(id2label)]

In [None]:
import evaluate

# Metric
# Module-level "seqeval" metric object: the training loop feeds it with
# add_batch() and compute_metrics() reads it with compute().
metric = evaluate.load("seqeval")

def get_labels(predictions, references):
    """Convert batched class-id tensors into lists of tag strings.

    Positions whose reference id is ``-100`` (the ignored index) are dropped
    from both outputs, so predictions and references stay aligned.

    Args:
        predictions: Tensor of predicted class ids, one row per sequence.
        references: Tensor of gold class ids with the same layout.

    Returns:
        A ``(true_predictions, true_labels)`` tuple; each element is a list
        with one list of ``label_list`` tags per sequence.
    """
    # .cpu() is a no-op for tensors that already live on the CPU, so a single
    # code path serves every device and no global `device` is needed.
    y_pred = predictions.detach().cpu().tolist()
    y_true = references.detach().cpu().tolist()

    true_predictions = []
    true_labels = []
    for pred_row, gold_row in zip(y_pred, y_true):
        # Keep only positions that carry a real label (ignored index removed).
        kept = [(p, g) for p, g in zip(pred_row, gold_row) if g != -100]
        true_predictions.append([label_list[p] for p, _ in kept])
        true_labels.append([label_list[g] for _, g in kept])
    return true_predictions, true_labels

def compute_metrics(return_entity_level_metrics=True):
    """Compute the module-level ``metric`` and reshape its result.

    Args:
        return_entity_level_metrics: When truthy, return every entry of the
            computed result, with nested dicts flattened into
            ``"<key>_<subkey>"`` entries.  Otherwise return only the four
            overall scores under the short names ``precision``, ``recall``,
            ``f1`` and ``accuracy``.

    Returns:
        A flat dict of metric values.
    """
    results = metric.compute()

    if not return_entity_level_metrics:
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

    # Flatten one level of nesting; non-dict values are copied unchanged.
    flattened = {}
    for key, value in results.items():
        if not isinstance(value, dict):
            flattened[key] = value
            continue
        flattened.update({f"{key}_{n}": v for n, v in value.items()})
    return flattened

# TRAIN

In [None]:
import torch
from torch.optim import AdamW
from tqdm.auto import tqdm
from sklearn.metrics import classification_report, f1_score

optimizer = AdamW(model.parameters(), lr=2e-5)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model.to(device)

model.train()
print(device)
for epoch in range(10):
    for batch in tqdm(train_dataloader):
        # Move every tensor of the batch to the model's device.
        inputs = {name: tensor.to(device) for name, tensor in batch.items()}

        # Clear the gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass, backward pass, parameter update.
        outputs = model(**inputs)

        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # Accumulate this batch's predictions for the per-epoch metric.
        predictions = torch.argmax(outputs.logits, dim=-1)
        labels = batch["labels"]
        preds, refs = get_labels(predictions, labels)
        metric.add_batch(predictions=preds, references=refs)

    # Metrics over the batches accumulated during this epoch.
    eval_metric = compute_metrics()
    print(f"Epoch {epoch}:", eval_metric)
      


# TEST

In [None]:
from metrics import segmentation_metric

In [None]:
# Evaluate the model record by record on the validation data.
model.to(device)
model.eval()
print(device)

valid_metric = segmentation_metric()

for record in tqdm(valid_data):
    # Each field is a one-element list wrapping the per-node list.
    nodes = record['nodes']
    xpaths = record['xpaths']
    node_labels = record['node_labels']

    encoding = processor(nodes=nodes, xpaths=xpaths, node_labels=node_labels, return_offsets_mapping=True, return_tensors="pt").to(device)

    # Entries that are removed from the encoding before it is passed to the model.
    offset_mapping = encoding.pop("offset_mapping")
    labels = encoding.pop("labels")

    with torch.no_grad():
        outputs = model(**encoding)

    predictions = outputs.logits.argmax(dim=-1)
    # NOTE(review): `predictions` is computed but never reported to
    # `valid_metric`; the result key expected for predicted xpaths is not
    # visible in this file - confirm against `metrics.segmentation_metric`.

    # Use this record's own xpaths and labels.  The previous code read
    # `batch["xpaths"]` from a stale training batch and compared the wrapped
    # label list to 1, so no xpath was ever selected.
    true_xpaths = [xpath for xpath, label in zip(xpaths[0], node_labels[0]) if label == 1]
    valid_metric.add_result({"true_xpaths" : true_xpaths})

print(valid_metric.get_metric())
    
