In [None]:
import json
!pip install benepar
import benepar

benepar.download('benepar_en3')

import spacy
!python -m spacy download en_core_web_sm
from collections import Counter
from nltk import Tree

In [None]:
# Load spaCy's small English model; the resulting `nlp` object is the shared
# pipeline that every prompt is run through further down.
nlp = spacy.load("en_core_web_sm")

# Add benepar (using the 'benepar_en3' model) to the end of the pipeline.
# NOTE(review): this is presumably what provides the `sent._.parse_string`
# attribute read during tree extraction, and it relies on `import benepar`
# having been executed first -- confirm against the benepar documentation.

nlp.add_pipe("benepar", config={"model": "benepar_en3"})

In [None]:
model = "bloom"     # "mistral" / "bloom" / "llama"


def _load_prompts(path):
    """Read the "prompt" field from every record of a JSON-lines file.

    Args:
        path: Location of a file holding one JSON object per line, each with
            a "prompt" key.

    Returns:
        A list with the "prompt" value of each record, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no "prompt" key.
    """
    # The context manager closes the handle even when a line fails to parse;
    # UTF-8 is given explicitly so decoding does not depend on the platform
    # default encoding.
    with open(path, encoding="utf-8") as handle:
        # Blank lines (e.g. a trailing newline at end of file) are skipped
        # instead of being handed to json.loads, which would reject them.
        return [json.loads(line)["prompt"] for line in handle if line.strip()]


toxic_prompts = _load_prompts(f"results/{model}/most_toxic.jsonl")

nontoxic_prompts = _load_prompts(f"results/{model}/least_toxic.jsonl")


def extract_constituency_trees(sentences):
    """Parse each input text and return its constituency trees.

    Every text is run through the module-level ``nlp`` pipeline. For each
    sentence span the pipeline yields, the bracketed parse stored in
    ``span._.parse_string`` is converted with ``Tree.fromstring``.

    Args:
        sentences: Iterable of strings to parse.

    Returns:
        A flat list of trees, in input order. A single input text contributes
        one tree per sentence span, so the result may be longer than the
        input.
    """
    return [
        Tree.fromstring(span._.parse_string)
        for text in sentences
        for span in nlp(text).sents
    ]

# Parse both prompt sets with the same routine; the toxic set is processed
# first, then the non-toxic one.
toxic_trees, nontoxic_trees = map(
    extract_constituency_trees, (toxic_prompts, nontoxic_prompts)
)

In [None]:
def extract_non_leaf_phrases(tree):
    """Recursively extract non-leaf phrase structures from a constituency tree.

    A node is treated as a phrase when it is a ``Tree`` whose ``height()`` is
    greater than 2. Anything else (a plain leaf value, or a ``Tree`` of
    height 2 or less) yields an empty list.

    Args:
        tree: A ``Tree`` node or a leaf value.

    Returns:
        A list of ``(label, children)`` tuples: first the entry for ``tree``
        itself, then the entries produced for each child, in child order.
        ``children`` is a tuple holding, for every ``Tree`` child, the full
        list this function returns for that child.

    NOTE(review): because ``children`` embeds each child's complete result
    list, a descendant phrase appears both nested inside its ancestors'
    entries and as its own flat entry, so the string form of an entry grows
    quickly with depth. This shape is kept as-is so existing counts do not
    change -- confirm it is intended.
    """
    if not (isinstance(tree, Tree) and tree.height() > 2):
        return []

    # Recurse into each Tree child exactly once and reuse the result for both
    # the nested structure and the flat list. Recursing separately for each
    # use doubles the work at every level, which is exponential in depth.
    # Non-Tree children are skipped: they would only contribute empty lists.
    child_results = [
        extract_non_leaf_phrases(child) for child in tree if isinstance(child, Tree)
    ]

    phrases = [(tree.label(), tuple(child_results))]
    for child_phrases in child_results:
        phrases.extend(child_phrases)
    return phrases

# Extract and count non-leaf phrase type subtrees
def extract_and_count_non_leaf_subtrees(trees):
    """Count how often each phrase structure occurs across a set of trees.

    Each tree is passed to ``extract_non_leaf_phrases``; every structure it
    returns is converted to its ``str()`` form, and that string is the key
    being counted.

    Args:
        trees: Iterable of constituency trees.

    Returns:
        A ``Counter`` mapping each stringified structure to its number of
        occurrences over all trees.
    """
    return Counter(
        str(phrase)
        for tree in trees
        for phrase in extract_non_leaf_phrases(tree)
    )

# Tally the phrase-structure patterns of the toxic prompts. To analyse the
# non-toxic prompts instead, pass `nontoxic_trees` to the call below.
non_leaf_subtree_patterns = extract_and_count_non_leaf_subtrees(toxic_trees)

# Report the five most frequent patterns together with their counts.
print("Most common non-leaf subtrees by phrase types:")
for subtree, count in non_leaf_subtree_patterns.most_common(5):
    print(subtree, count, sep=": ")
