### 0. Import libraries and implement functions

In [1]:
from pylatexenc.latexwalker import LatexWalker, LatexEnvironmentNode, LatexCharsNode, LatexCommentNode,\
                                    LatexGroupNode, LatexMathNode, LatexMacroNode, LatexSpecialsNode
import sys, re, os

In [2]:
def find_tex_files(dir, max_depth):
    if max_depth == 0 or not os.path.isdir(dir):
        return []

    fs = os.listdir(dir)

    if dir != ".":
        fs = [dir + "/" + f for f in fs]

    sub = [f for f in fs if os.path.isdir(f)]
    fs = [f for f in fs if os.path.isfile(f) and f.endswith(".tex")]

    for d in sub:
        fs.extend(find_tex_files(d, max_depth - 1))
    return fs

In [3]:
def find_begin_document(files):
    for file in files:
        with open(file, "r") as f:
            text = f.read()
            if r"\begin{document}" in text:
                return file
    return None

In [4]:
def find_include(tex):
    pattern = r"\\(?:include|input)\{([\w\d\/_]+)\}"
    lines = tex.split("\n")
    files = []
    for i in range(len(lines)):
        line = lines[i]

        percent_idx = 0
        while True:
            percent_idx = line.find("%", percent_idx)
            if percent_idx == -1:
                break

            if percent_idx == 0 or (percent_idx > 0 and line[percent_idx - 1] != '\\'):
                break
            percent_idx += 1

        if percent_idx == 0:
            continue
        elif percent_idx != -1:
            line = line[:percent_idx - 1]

        captured = re.findall(pattern, line)
        if len(captured) > 0:
            files.extend(captured)

    return list(set(files))

In [5]:
def get_latex_nodes(fp):
    with open(fp, "r") as f:
        text = f.read()

    newcommand_pattern = r"^\\newcommand\{([^}]+)\}(?:\[[^]]+\])?\{(.+)\}$"
    newcommands = re.findall(newcommand_pattern, text, flags=re.M)
    # print(newcommands)
    for (short_hand, cmd) in newcommands:
        short_hand = short_hand.replace("\\", "\\\\")
        cmd = cmd.replace("\\", "\\\\")
        text = re.sub(short_hand, cmd, text)

    text = re.sub(newcommand_pattern, "", text)

    w = LatexWalker(text)
    nodes, _, _ = w.get_latex_nodes()
    return nodes

In [6]:
SENTENCE_PATTERN = re.compile(r"(?<!\b[A-Z])(?<![Ee][Tt] [Aa][Ll])\.\s+(?=[A-Z])")

def split_sentences(text, level):
    sentences = SENTENCE_PATTERN.split(text)
    sentences = [(sentence, level) for sentence in sentences]

    return sentences

In [7]:
LEVELS = {
    "document": 0,
    "abstract": 1,
    "section": 1,
    "subsection": 2,
    "subsubsection": 3,
    "paragraph": 4,
    "subparagraph": 5,
    "itemize": 6,
    "item": 7,
    "leaf": 8,
}

In [17]:
def hierarchy_nodes(nodes, dir, append_trailing=False):
    text = ""
    tokens = []
    refs = []

    for node in nodes:
        if node == None or isinstance(node, LatexCommentNode):
            continue

        if isinstance(node, LatexCharsNode):
            chars = node.chars
            text += chars
        elif isinstance(node, LatexMathNode):
            text += node.latex_verbatim()
        elif isinstance(node, LatexGroupNode):
            text += node.latex_verbatim()
            # tokens.extend(node.nodelist)
        elif isinstance(node, LatexEnvironmentNode):
            env_name = node.environmentname.lower()
            if env_name in ["figure", "figure*", "equation", "equation*", "align", "align*", "table", "remark", "remark*"]:
                tokens.extend(split_sentences(text, LEVELS["leaf"]))
                tokens.append((node.latex_verbatim(), LEVELS["leaf"]))
                text = ""
            elif env_name == "itemize":
                tokens.extend(split_sentences(text, LEVELS["leaf"]))
                tokens.append((env_name, LEVELS[env_name]))
                latex = node.latex_verbatim()
                # pattern = r"\\(begin|end)\{" + env_name + r"\}" + r"(\[[^]]+\])?"
                pattern = r"\\(begin|end)\{itemize\}(\[[^]]+\])?"
                latex = re.sub(pattern, "", latex, flags=re.IGNORECASE).strip()

                items = re.split(r"\\item", latex, flags=re.IGNORECASE)
                for item in items:
                    item = item.strip()
                    if len(item) == 0:
                        continue

                    tokens.append(("item", LEVELS["item"]))
                    tokens.extend(split_sentences(item, LEVELS["leaf"]))

            elif env_name in "document":
                tokens.append((env_name, LEVELS[env_name]))
                sub_tokens, sub_refs = hierarchy_nodes(node.nodelist, dir, True)
                tokens.extend(sub_tokens)
                refs.extend(sub_refs)
            elif env_name in "abstract":
                tokens.append((env_name, LEVELS[env_name]))
                latex = node.latex_verbatim()
                latex = re.sub(r"\\(begin|end)\{abstract\}", "", latex, flags=re.IGNORECASE).strip()
                tokens.extend(split_sentences(latex, LEVELS["leaf"]))
            else:
                # print(env_name, "[Environment]")
                sub_tokens, sub_refs = hierarchy_nodes(node.nodelist, dir, True)
                tokens.extend(sub_tokens)
                refs.extend(sub_refs)
        elif isinstance(node, LatexMacroNode):
            if node.macroname in ["input", "include"]:
                if len(node.nodeargd.argnlist) != 1 or len(node.nodeargd.argnlist[0].nodelist) != 1 \
                        or not isinstance(node.nodeargd.argnlist[0].nodelist[0], LatexCharsNode):
                    print("\t"*tabs, "?Empty?", "[\\include]")
                    exit(1)
                else:
                    tokens.extend(split_sentences(text, LEVELS["leaf"]))
                    text = ""

                    fp = os.path.join(dir, node.nodeargd.argnlist[0].nodelist[0].chars)
                    dependencies_nodes = get_latex_nodes(fp)
                    print("Parse", fp)
                    sub_tokens, sub_refs = hierarchy_nodes(dependencies_nodes, dir, True)
                    tokens.extend(sub_tokens)
                    refs.extend(sub_refs)

                    text = ""
            elif node.macroname in ["section", "subsection", "subsubsection", "paragraph", "subparagraph"]:
                tokens.extend(split_sentences(text, LEVELS["leaf"]))
                tokens.append((node.latex_verbatim(), LEVELS[node.macroname]))
                text = ""
            elif node.macroname in ["label", "footnote"]:
                latex = node.latex_verbatim()
                pattern = fr"\\{node.macroname}" + "{[^}]+}"
                latex = re.sub(pattern, "", latex)
                text += latex
            elif node.macroname in ["cite", "citep", "citet"]:
                latex = node.latex_verbatim()
                latex = re.sub(r"\\" + node.macroname + r"\{", "", latex)
                latex = latex[:-2]
                refs.extend(latex.split(","))
            else:
                text += node.latex_verbatim()
                pass
        elif isinstance(node, LatexSpecialsNode):
            text += node.specials_chars

    if append_trailing and text != "":
        tokens.extend(split_sentences(text, LEVELS["leaf"]))

    return tokens, refs

In [25]:
def hierarchy_version(version_directory):
    main_tex_fp = os.path.join(version_directory, "main.tex")

    if not os.path.exists(main_tex_fp):
        depth1_tex_files = find_tex_files(version_directory, 1)
        main_tex_fp = find_begin_document(depth1_tex_files)
        if main_tex_fp == None:
            print("Not found main tex file")
            return None

    nodes = get_latex_nodes(main_tex_fp)
    nodes, refs = hierarchy_nodes(nodes, version_directory)
    nodes = [node for node in nodes if len(node[0]) > 0]
    refs = list(set(refs))

    ignore_idx = 0
    while ignore_idx < len(nodes) and nodes[ignore_idx] != ("document", 0):
        ignore_idx += 1
    nodes = nodes[ignore_idx:]

    node_stack = [0]
    node_hierarchy = {0: 0}

    for i in range(1, len(nodes)):
        last_node = nodes[node_stack[-1]]
        last_node_level = last_node[1]

        current_node = nodes[i]
        current_node_level = current_node[1]

        if current_node_level > last_node_level:
            node_hierarchy[i] = node_stack[-1]
            node_stack.append(i)
        elif current_node_level == last_node_level:
            node_hierarchy[i] = node_hierarchy[node_stack[-1]]
        else:
            while len(nodes) > 0 and current_node_level <= nodes[node_stack[-1]][1]:
                node_stack.pop()

            if len(nodes) == 0:
                node_hierarchy[i] = 0
            else:
                node_hierarchy[i] = node_stack[-1]

            node_stack.append(i)

    return nodes, node_hierarchy, refs


### 1. Testing

In [26]:
nodes, node_hierarchy, refs = hierarchy_version("../../23127247_milestone1/2210.16424/tex/2210.16424v1/")
print("Refs:", refs)
for node, parent in node_hierarchy.items():
    print(nodes[node], nodes[parent])

Parse ../../23127247_milestone1/2210.16424/tex/2210.16424v1/math_commands.tex
Refs: ['zhu2019deep', 'cao2015toward', 'bonawitz2021federate', 'wang2021field', 'liu2021federaser', 'thomee2016yfcc100', 'Berlekamp1968AlgebraicCT', 'chen2020breaking', 'pmlr-v119-guo20', 'gardner2014measurin', 'so2022lightsecagg', 'geiping2020invertin', 'pmlr-v119-guo20c', 'hartigan1979algorith', 'ghosh2020efficient', 'gardner2014measuring', 'gardner20143', 'wu2022federated', 'bourtoule2021machin', 'caldas2018lea', 'bourtoule2021machine', 'bonawitz2017practical', 'kedlaya2011fas', 'gan2017', 'vassilvitskii2006', 'blackard1999comparativ', 'gandikota2021vqsg', 'kissner2005privacy', 'deng2009imagene', 'hutter2018cance', 'han2018mapping', 'chen2022fundamenta', 'chung2022federate', 'kairouz2021advances', 'dennis2021heterogeneit', 'so2022lightsecag', 'mahajan2012plana', 'fredrikson2015model', 'dennis2021heterogeneity', 'bell2020secure', 'chien2018query', 'li2022secur', 'guha2003clustering', 'seo2012constan', 'acha