In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import config
import sys
import os

# Use the parent of the current working directory as the project root and put
# it first on the import path (presumably so the project packages imported in
# the following cell resolve from the notebook's folder -- TODO confirm layout).
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
config.root_path = project_root
sys.path.insert(0, project_root)

In [3]:
from db.dbv2 import Table, AugmentedTable, TrainTestTable
from src.dataset.utils import truncate_by_token, flatten, dedupe_list, truncate_string

### === Prepare Dataset ===

In [10]:
def get_data(
    dataset_type: str,
    *,
    num_sentences: int = 100,
    offset: int = 200,
    max_segment_length: int = 99,
):
    """Return a fixed window of items and boundary labels for one dataset.

    Every segment returned by ``Table(dataset_type).get_all_segments()`` is
    truncated to its first ``max_segment_length`` rows, all segments are
    concatenated with ``flatten``, and the slice
    ``[offset : offset + num_sentences]`` of the flat sequence is returned.

    Args:
        dataset_type: Name handed to ``Table`` to select the dataset.
        num_sentences: Number of consecutive items to return.
        offset: Index in the flattened sequence at which the window starts.
        max_segment_length: Maximum number of rows kept from each segment.

    Returns:
        A ``(segments_to_test, labels_to_test)`` pair of parallel sequences.
        The first holds element ``[1]`` of every selected row (presumably the
        sentence text -- TODO confirm against ``Table``); the second holds 1
        for the first row of a segment and 0 for every other row.

    Raises:
        ValueError: If any of the window parameters is negative.
    """
    if num_sentences < 0 or offset < 0 or max_segment_length < 0:
        raise ValueError(
            "num_sentences, offset and max_segment_length must be non-negative"
        )

    table = Table(dataset_type)
    all_segments = table.get_all_segments()

    # Build items and labels in a single pass so the result stays correct even
    # if get_all_segments() hands back a one-shot iterator.
    segments = []
    segments_labels = []
    for segment in all_segments:
        sentences = [row[1] for row in segment][:max_segment_length]
        segments.append(sentences)
        # Only the first row of each segment is marked as a boundary.
        segments_labels.append([1 if i == 0 else 0 for i in range(len(sentences))])

    flattened_segments = flatten(segments)
    flattened_labels = flatten(segments_labels)

    window_end = offset + num_sentences
    segments_to_test = flattened_segments[offset:window_end]
    labels_to_test = flattened_labels[offset:window_end]

    return segments_to_test, labels_to_test

### === Testing ===

In [11]:
from src.determinor import Determinor
from nltk.metrics.segmentation import pk, windowdiff

In [12]:
# Create a single Determinor instance; the evaluation loop below reuses it for
# every dataset.
determinor = Determinor()

In [14]:
# Score the determinor's predictions against the reference labels for a subset
# of the Choi datasets. Variants not currently evaluated: "choi_9_11", "choi_6_8".
for dataset_type in ("choi_3_5", "choi_3_11"):
    segments, labels = get_data(dataset_type)

    print(f"evaluating {dataset_type}")
    predictions = determinor.query_batch_data(segments)

    # Encode each prediction as a boundary flag: True -> 0, anything else -> 1.
    preds = [int(not (p == True)) for p in predictions]

    # The first label is skipped before comparison (presumably because each
    # prediction concerns a pair of adjacent items -- TODO confirm).
    str_labels = ''.join(map(str, labels[1:]))
    str_predictions = ''.join(map(str, preds))
    print()
    print(f"L: {str_labels}")
    print(f"P: {str_predictions}")

    # Report Pk and WindowDiff for window sizes 2 through 7.
    for k in range(2, 8):
        pk_score = pk(str_labels, str_predictions, k=k)
        wd_score = windowdiff(str_labels, str_predictions, k=k)
        print(f"k: {k}, pk: {pk_score}, wd: {wd_score}")

Using dataset: choi_3_5
evaluating choi_3_5
..|.|||||||...|.||..|..|....||||||.|..|....|..|..||..|||||...|.||||...|||..|..|||...|||..|..|..|.|.
L: 001000100100001001001001000010000100001000010010001000100100010000100010000100001000010001001000010
P: 001011111110001011001001000011111101001000010010011001111100010111100011100100111000111001001001010
k: 2, pk: 0.24489795918367346, wd: 0.40816326530612246
k: 3, pk: 0.18556701030927836, wd: 0.5257731958762887
k: 4, pk: 0.08333333333333333, wd: 0.6354166666666666
k: 5, pk: 0.0, wd: 0.7263157894736842
k: 6, pk: 0.0, wd: 0.7978723404255319
k: 7, pk: 0.0, wd: 0.8387096774193549
Using dataset: choi_3_11
evaluating choi_3_11
......||......|...||..|..||.|.....|.|||||.....|..|....|.|.|..||.......||..|......|.||.|||....|....|
L: 000000010000001000000000010010000010000010000010010000000010001000100001001000000100000100000000001
P: 000000110000001000110010011010000010111110000010010000101010011000000011001000000101101110000100001
k: 2, pk: 0.265306122