In [1]:
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
from lxml import html, etree
from requests_html import HTMLSession
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Single checkpoint shared by tokenizer and encoder: they must come from the
# same checkpoint, or the token ids will not match the model's vocabulary.
MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

In [7]:
def getNodes(code):
    """Map the XPath of every node under <body> (inclusive) to the node.

    Parameters
    ----------
    code : str
        Raw HTML of the page.

    Returns
    -------
    dict
        XPath -> lxml node, in document order (depth-first, pre-order),
        starting with <body> itself. Comments and processing instructions
        are included; cleanNodes filters them out. The XPaths are computed
        relative to the root returned by ``html.fromstring``, so they can
        be resolved against a fresh parse of the same HTML. Returns an
        empty dict if the parsed root has no <body> child.
    """
    root = html.fromstring(code)
    # Build paths against the parsed root, not <body>, so that
    # getSimilarityMatch can look them up on a re-parsed document.
    tree = etree.ElementTree(root)
    body = root.find('body')
    if body is None:
        return {}
    # iter() yields the same pre-order sequence as the old recursive walk,
    # but cannot hit Python's recursion limit on deeply nested markup.
    return {tree.getpath(node): node for node in body.iter()}


def cleanNodes(nodes, skip_tags=("script", "style")):
    """Keep only element nodes that carry non-blank text of their own.

    Parameters
    ----------
    nodes : dict
        XPath -> node mapping, as produced by getNodes.
    skip_tags : iterable of str, optional
        Tags whose text is never visible page copy (inline JS/CSS by
        default). They are dropped so they cannot win a similarity match.
        Pass an empty tuple to keep them.

    Returns
    -------
    dict
        XPath -> the node's ``.text``, unstripped. Only the text before the
        node's first child is used; ``.tail`` text is not collected.
    """
    skip = frozenset(skip_tags)
    res = {}
    for xpath, node in nodes.items():
        # lxml gives comments, processing instructions and entities a
        # non-string tag (the factory function), so this rejects all three.
        if not isinstance(node.tag, str) or node.tag in skip:
            continue
        text = node.text
        if text is None or not text.strip():
            continue
        res[xpath] = text
    return res

def mean_pooling(model_output, attention_mask):
    """Average token embeddings, ignoring positions masked out by attention.

    ``model_output[0]`` holds the per-token embeddings. The mask is broadcast
    over the hidden dimension. The token count is clamped so that an
    all-zero mask yields zeros instead of dividing by zero.
    """
    hidden = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts


def embbed(nodes):
    """Encode each text value into an L2-normalised sentence embedding.

    Parameters
    ----------
    nodes : dict
        key -> text (e.g. XPath -> node text from cleanNodes).

    Returns
    -------
    dict
        Same keys, each mapped to the mean-pooled, normalised output of the
        module-level ``tokenizer``/``model``. Texts are encoded one per
        forward pass, without gradients.
    """
    embeddings = {}
    for key, text in nodes.items():
        batch = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            output = model(**batch)
        pooled = mean_pooling(output, batch['attention_mask'])
        embeddings[key] = F.normalize(pooled, p=2, dim=1)
    return embeddings
    



# APIs to call

In [3]:
def loadAsyncronousEmbeddings(code):
    """Parse HTML and embed the text of every non-blank node under <body>.

    Despite the name, this runs synchronously: parse -> filter -> embed.

    Parameters
    ----------
    code : str
        Raw HTML of the page.

    Returns
    -------
    dict
        XPath -> normalised embedding, suitable as the ``embeddedNodes``
        argument of getSimilarityMatch.
    """
    return embbed(cleanNodes(getNodes(code)))


def getSimilarityMatch(code, embeddedNodes, matchString, uniqueIndentifier):
    """Tag the page element whose text is most similar to ``matchString``.

    Parameters
    ----------
    code : str
        Raw HTML of the page. It must be the same HTML that ``embeddedNodes``
        was built from, because the stored XPaths are resolved against a
        fresh parse of it.
    embeddedNodes : dict
        XPath -> embedding, as returned by loadAsyncronousEmbeddings.
    matchString : str
        Text to search for.
    uniqueIndentifier : object
        Converted with ``str`` and appended to the
        ``custom-text-suggestion-`` CSS class added to the best match.

    Returns
    -------
    bytes
        The re-serialised document (``etree.tostring``) with the class added.

    Raises
    ------
    ValueError
        If ``embeddedNodes`` is empty, so there is nothing to match against.
    """
    if not embeddedNodes:
        raise ValueError("embeddedNodes is empty; nothing to match against")

    # Encode the query through the same path as the page nodes so the
    # vectors are directly comparable.
    compareEmbed = embbed({"query": matchString})["query"]

    scores = {
        xpath: util.pytorch_cos_sim(embedding, compareEmbed).tolist()[0][0]
        for xpath, embedding in embeddedNodes.items()
    }
    maxNodeXPath = max(scores, key=scores.get)

    tree = html.fromstring(code)
    marker = "custom-text-suggestion-" + str(uniqueIndentifier)
    for element in tree.getroottree().xpath(maxNodeXPath):
        # Elements without a class attribute used to raise KeyError here.
        existing = element.get("class")
        element.set("class", (existing + " " + marker) if existing else marker)

    return etree.tostring(tree)
    


# Example API Usage

In [5]:
# Fetch a live page to demo the API. Fail fast on network/HTTP errors so an
# error page is never embedded by mistake; 30 s caps a hung connection.
session = HTMLSession()
r = session.get("https://wix.com", timeout=30)
r.raise_for_status()
code = html.fromstring(r.text)

In [8]:
# Index the page once, then mark the node that best matches the query.
res = loadAsyncronousEmbeddings(r.text)
labelledcode = getSimilarityMatch(
    r.text,
    embeddedNodes=res,
    matchString="Create a website without limits",
    uniqueIndentifier=1,
)

# Experiments

In [177]:
# Re-fetch the page for the step-by-step experiments below. Fail fast on
# network/HTTP errors; 30 s caps a hung connection.
session = HTMLSession()
r = session.get("https://wix.com", timeout=30)
r.raise_for_status()
code = html.fromstring(r.text)

In [178]:
# Same pipeline as loadAsyncronousEmbeddings, with every stage kept so the
# intermediate mappings can be inspected in later cells.
nodes = getNodes(r.text)              # XPath -> lxml node
cleanedNodes = cleanNodes(nodes)      # XPath -> non-blank node text
embeddedNodes = embbed(cleanedNodes)  # XPath -> normalised embedding

In [130]:
# Encode the query with embbed, the same path used for the page nodes, rather
# than repeating the tokenize -> pool -> normalise steps here.
compareString = "Create a website without limits"
compareEmbed = embbed({compareString: compareString})[compareString]

In [139]:
# Cosine similarity of every embedded node against the query embedding.
res = {
    xpath: util.pytorch_cos_sim(embedding, compareEmbed).item()
    for xpath, embedding in embeddedNodes.items()
}

In [143]:
# Show the text of the best-scoring node. Every key of `res` comes from
# `nodes`, so this lookup always succeeds. The previously hard-coded
# '/body/...' path did not match the '/html/body/...'-style keys that
# getNodes now produces.
maxXPath = max(res, key=res.get)
nodes[maxXPath].text

'Create a website without limits'