In [1]:
import spacy
import json
import matplotlib.pyplot as plt
from collections import Counter
from tqdm import tqdm

In [2]:
def calculate_height(token):
    """Return the height of the dependency subtree rooted at ``token``.

    A token without children has height 0; every additional level of
    descendants below ``token`` adds 1. Computed level by level, so it does
    not recurse.
    """
    height = -1
    frontier = [token]
    while frontier:
        height += 1
        # Move one level down: gather every child of every node on this level.
        frontier = [child for node in frontier for child in node.children]
    return height
    
def count_subtree_nodes(token):
    """Return how many tokens are in the subtree rooted at ``token`` (inclusive)."""
    total = 0
    pending = [token]
    while pending:
        node = pending.pop()
        total += 1
        pending.extend(node.children)
    return total

In [3]:
def _token_depth(token):
    """Return the number of head hops from ``token`` up to its sentence root."""
    depth = 0
    # A root token is its own head, so the walk stops there.
    while token.head != token:
        depth += 1
        token = token.head
    return depth


def _save_int_histogram(values, xlabel, png_path, title=None):
    """Save a histogram of non-empty integer ``values`` with one bin per integer."""
    lo, hi = min(values), max(values)
    # Edges lo-0.5 .. hi+0.5 give every integer its own bin centred on it.
    # The previous bins=(hi - lo) put the two largest values into one bin
    # and raised when all values were equal (bins=0).
    edges = [k - 0.5 for k in range(lo, hi + 2)]
    fig, ax = plt.subplots()
    try:
        ax.hist(values, bins=edges)
        ax.set_xlabel(xlabel)
        ax.set_ylabel("Frequency")
        if title:
            ax.set_title(title)
        fig.savefig(png_path)
    finally:
        plt.close(fig)


def depth_analyze(file_name, nlp=None, input_dir="perturbed_text", output_dir="output"):
    """Parse a JSON-lines file with spaCy and summarize dependency-tree shape.

    Each non-blank line of ``{input_dir}/{file_name}`` must be a JSON object
    with a ``"text"`` field. For every parsed token, the number of head hops
    to its root is recorded. For every ``ROOT`` token, the height
    (``calculate_height``) and node count (``count_subtree_nodes``) of its
    subtree are recorded.

    The three averages are:
    - printed;
    - appended as a tab-separated line to
      ``{output_dir}/depth_analyze_result.txt``;
    - each plotted as a histogram PNG in ``output_dir``.

    Args:
        file_name: name of the input file inside ``input_dir``.
        nlp: an already-loaded spaCy pipeline. If None, ``en_core_web_sm``
            is loaded; pass one in to avoid reloading it for every file.
        input_dir: directory containing ``file_name``.
        output_dir: directory for the summary file and PNGs; created if missing.

    Raises:
        ValueError: if no tokens / root tokens were parsed, so there is
            nothing to average or plot.
    """
    import os

    if nlp is None:
        nlp = spacy.load("en_core_web_sm")

    input_path = os.path.join(input_dir, file_name)
    with open(input_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
        texts = [json.loads(line)["text"] for line in f if line.strip()]

    all_depth, root_depth, root_nodes = [], [], []
    # nlp.pipe batches the texts, which is much faster than one nlp() call per text.
    for doc in tqdm(nlp.pipe(texts), total=len(texts)):
        for token in doc:
            all_depth.append(_token_depth(token))
            if token.dep_ == "ROOT":
                root_depth.append(calculate_height(token))
                root_nodes.append(count_subtree_nodes(token))

    if not all_depth or not root_depth:
        raise ValueError(f"{input_path}: no tokens parsed, nothing to analyze")

    avg_all_depth = sum(all_depth) / len(all_depth)
    avg_root_depth = sum(root_depth) / len(root_depth)
    avg_root_nodes = sum(root_nodes) / len(root_nodes)
    print(f"{file_name}\tavg_all_depth: {avg_all_depth}\tavg_root_depth: {avg_root_depth}\tavg_root_nodes:{avg_root_nodes}")

    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "depth_analyze_result.txt"), "a", encoding="utf-8") as w:
        w.write(f"{file_name}\t{avg_all_depth}\t{avg_root_depth}\t{avg_root_nodes}\n")

    # For "*.json" inputs these names match the originals, including the
    # trailing underscore in "_all_depth_".
    stem = os.path.splitext(file_name)[0]
    _save_int_histogram(all_depth, "Depth of Nodes",
                        os.path.join(output_dir, f"{stem}_all_depth_.png"), title=file_name)
    _save_int_histogram(root_nodes, "Number of Nodes",
                        os.path.join(output_dir, f"{stem}_root_nodes.png"), title=file_name)
    _save_int_histogram(root_depth, "Depth of Root Nodes",
                        os.path.join(output_dir, f"{stem}_root_depth.png"), title=file_name)
    

In [4]:
# Input files (read from perturbed_text/). For each one, depth_analyze prints
# a summary line, appends it to the results file and saves three histograms.
file_list = [
    "gpt3.5_mixed_test_llama3_8b_instruct_rewrite.json",
]
for file_name in file_list:
    depth_analyze(file_name)

100%|██████████| 1000/1000 [01:28<00:00, 11.36it/s]


3.1282040604749137 5.762543360869721 24.905088867431964
