In [1]:
!pip install treelib

Collecting treelib
  Downloading treelib-1.7.0-py3-none-any.whl.metadata (1.3 kB)
Downloading treelib-1.7.0-py3-none-any.whl (18 kB)
Installing collected packages: treelib
Successfully installed treelib-1.7.0


In [2]:
# Run if working locally
%load_ext autoreload
%autoreload 2

In [3]:
import sqlite3
from sqlite3 import Error
import pickle
import os, sys
import config

config.root_path = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.insert(0, config.root_path)

from src.dataset.dataset import RawData
from src.dataset.wikisection_preprocessing import (
    tokenize,
    clean_sentence,
    preprocess_text_segmentation,
    format_data_for_db_insertion,
)
from src.dataset.utils import truncate_by_token
from db.dbv2 import Table, AugmentedTable, TrainTestTable
import pprint


from utils.metrics import windowdiff, pk
from treelib import Node, Tree

In [4]:
dataset_type = "city"
table = Table(dataset_type)

Using dataset: wikisection_city


In [45]:
class CoherenceNode():
    def __init__(self, id, word, vector):
        self.id = id
        self.word = word
        self.vector = vector
        # self.level = level
    
    def __repr__(self):
        return f'Node({self.id}, \'{self.word}\', {self.vector})'

In [78]:
T = Tree()
G = Tree()

In [79]:
# (id, word_text, vector_representation)

segments = [
    [(0, "hello", 5), (1, "world", 6), (2, "earth", 7)],
    [(3, "school", 3), (4, "work", 5)],
    [(5, "assignments", 6), (6, "deadline", 4), (7, "lazy", 2), (8, "midterms", 5)],
    [(9, "gym", 4.5), (10, "math", 6), (11, "science", 6.5), (12, "world", 2)],
]

In [80]:
T.create_node("Harry", "harry")  # root node
T.create_node("Jane", "jane", parent="harry")
T.create_node("Bill", "bill", parent="harry")
G.create_node("Diane", "diane",)
G.create_node("Mary", "mary", parent="diane")
G.create_node("Mark", "mark", parent="diane")

Node(tag=Mark, identifier=mark, data=None)

In [82]:
print(T.show(stdout=False))
print(G.show(stdout=False))

Harry
├── Bill
└── Jane

Diane
├── Mark
└── Mary



In [83]:
T.paste("bill", G, deep=False)

In [84]:
print(T.show(stdout=False))
print(G.show(stdout=False))

Harry
├── Bill
│   └── Diane
│       ├── Mark
│       └── Mary
└── Jane

Diane
├── Mark
└── Mary



In [118]:
T.get_node(T.root).identifier

'harry'

In [177]:
# simulation

trees = []

id = 0
for segment in segments:
    all_roots = []
    for tree in trees:
        all_roots.append((tree, tree.get_node(tree.root)))
            
    for word in segment:
        print('all roots', all_roots)
        word_text, word_vector = word[1], word[2]
        new_node = CoherenceNode(id, word_text, word_vector)
        if len(trees) == 0:
            new_tree = Tree()
            new_tree.create_node(word_text, id, data=new_node)
            id += 1
            trees.append(new_tree)
            continue

        placed = False
        G = Tree()
        G.create_node(word_text, id, data=new_node)
        id += 1

        trees_to_remove = []
        for original_tree, root in all_roots:
            print(new_node, root)
            if new_node.vector * root.data.vector >= 20:
                print(f"strength between {G.get_node(G.root).tag} and {root.tag}")
                G.paste(G.get_node(G.root).identifier, original_tree, deep=False)
                print(f"testing: {[x.get_node(x.root).identifier for x in trees]}")
                print(f"current node identifier: {original_tree.get_node(original_tree.root).identifier}")
                print(f"removing {[t.get_node(t.root).tag for t in list(filter(lambda x: x.root == original_tree.get_node(original_tree.root).identifier, trees))]}")
                # remove the old tree from the list of trees since we create a new one
                # with the current node as the root
                trees = list(filter(lambda x: x.get_node(x.root).identifier == original_tree.get_node(original_tree.root).identifier, trees)) 
                id += 1
                placed = True
                # original_tree.create_node(word_text, id, parent=leaf, data=new_node)

        if G.root is not None:
            print("G is created")
            print(G.show(stdout=False))
            trees.append(G)
            placed = True
            # id += 1

        # if not placed:
        #     print("not placed")
        #     new_tree = Tree()
        #     new_tree.create_node(word_text, id, data=new_node)
        #     id += 1
        #     trees.append(new_tree)
            
        # node = CoherenceNode(word[0], word[1], word[2])
        # for n in prev_graph.nodes():
        #     if n.vector*node.vector >= 30:
        #         temp_graph.add_edge(node, n, weight=n.vector*node.vector)

all roots []
all roots []
G is created
world

all roots []
G is created
earth

all roots [(<treelib.tree.Tree object at 0x10844ddf0>, Node(tag=hello, identifier=0, data=Node(0, 'hello', 5))), (<treelib.tree.Tree object at 0x10844dfa0>, Node(tag=world, identifier=1, data=Node(1, 'world', 6))), (<treelib.tree.Tree object at 0x10844df70>, Node(tag=earth, identifier=2, data=Node(2, 'earth', 7)))]
Node(3, 'school', 3) Node(tag=hello, identifier=0, data=Node(0, 'hello', 5))
Node(3, 'school', 3) Node(tag=world, identifier=1, data=Node(1, 'world', 6))
Node(3, 'school', 3) Node(tag=earth, identifier=2, data=Node(2, 'earth', 7))
strength between school and earth
testing: [0, 1, 2]
current node identifier: 2
removing ['earth']
G is created
school
└── earth

all roots [(<treelib.tree.Tree object at 0x10844ddf0>, Node(tag=hello, identifier=0, data=Node(0, 'hello', 5))), (<treelib.tree.Tree object at 0x10844dfa0>, Node(tag=world, identifier=1, data=Node(1, 'world', 6))), (<treelib.tree.Tree object a

ValueError: Duplicated nodes ['0', '1', '2', '5'] exists.

In [178]:
node = trees[0].get_node(0)
node.data

Node(0, 'hello', 5)

In [179]:
for leaf in trees[0].leaves():
    print(leaf)

Node(tag=hello, identifier=0, data=Node(0, 'hello', 5))
Node(tag=world, identifier=1, data=Node(1, 'world', 6))
Node(tag=earth, identifier=2, data=Node(2, 'earth', 7))


In [180]:
trees

[<treelib.tree.Tree at 0x10844df40>]

In [181]:
# for reference
# segments = [
#     [(0, "hello", 5), (1, "world", 6), (2, "earth", 7)],
#     [(3, "school", 3), (4, "work", 5)],
#     [(5, "assignments", 6), (6, "deadline", 4), (7, "lazy", 2), (8, "midterms", 5)],
#     [(9, "gym", 4.5), (10, "math", 6), (11, "science", 6.5), (12, "world", 2)],
# ]

for tree in trees:
    print(tree.show(stdout=False))

work
├── earth
├── hello
└── world

