In [76]:
import numpy as np

In [146]:
from typing import Callable
def CKY_Algorithm(sequence, get_terminate_parses: Callable, beam_search_strategy: Callable = None, enable_slide_windows = False, window_size = -1):
    if (enable_slide_windows and window_size <= 0) or sequence == None or len(sequence) == 0:
        return None
        
    T = len(sequence)
    records = []
    windows_beginning_offset = 0
    for t in range(T):
        print("t = " + str(t))
        w = get_terminate_parses(sequence[t])
        
        current_record = [[]] if len(records) == 0 else [[]] * (len(records[-1]) + 1)
        current_record[-1] = w

        # When the windows size is large, we may utilize array to store records, rather than list. The array allow us adopting 
        # sliding windows more effectively when the window size is large. However, we currently utilize list to store records.
        # To do sliding window, we need first pop the first record (oldest record inside the window), and for each of rest records,
        # we need pop its first element, as the first element of each record relate to the "word" we'd like to forget.
        if enable_slide_windows:
            if len(records) == window_size:
                records.pop(0)
                for i in range(0, len(records)):
                    records[i].pop(0)
                windows_beginning_offset += 1
                
        records.append(current_record)

        for inner_record_index in range(0, len(current_record) - 1):
            current_record_length = len(current_record) 
            current_record[current_record_length - inner_record_index - 2] = []
            cell_data = []
            for split_index in range(current_record_length - inner_record_index - 2, t - windows_beginning_offset):
                # coordination = (index, t)
                cell1_coordination = (current_record_length - inner_record_index - 2, split_index)
                cell2_coordination = (split_index + 1, t - windows_beginning_offset)
                cell_data += merge_parses_in_two_cells(records[cell1_coordination[1]][cell1_coordination[0]],\
                                                       records[cell2_coordination[1]][cell2_coordination[0]],\
                                                       cell1_coordination, cell2_coordination)
            current_record[current_record_length - inner_record_index - 2] = beam_search_strategy(cell_data) if beam_search_strategy is not None else cell_data
    return records


def get_terminate_parses(w):
    return ["(" + str(w) + ", " + str(w) + ")"]

def merge_parses_in_two_cells(cell1, cell2, cell1_id = None, cell2_id = None):
    return [[cell1_id, cell2_id]]

def select_tops(cell1, beam_size = 5):
    return cell1[:beam_size]

records = CKY_Algorithm(range(0, 10), get_terminate_parses, select_tops, False)

t = 0
t = 1
t = 2
t = 3
t = 4
t = 5
t = 6
t = 7
t = 8
t = 9


In [127]:
records[-1][0] # (t, span_start_index)

[[(0, 0), (1, 9)],
 [(0, 1), (2, 9)],
 [(0, 2), (3, 9)],
 [(0, 3), (4, 9)],
 [(0, 4), (5, 9)]]

In [147]:
import torch
import torch.nn as nn
import heapq

In [None]:
class NN_CYK_Model(nn.Module):
    def __init__(self, args):
        self.args = args
        self.device = 'cpu'
        self.NT = args['NT']
        self.T = args['T']
        self.V = args['cnt_words']
        self.s_dim = args['s_dim']
        self.r = args['r_dim']
        self.word_emb_size = args['word_emb_size']
        self.enable_slide_windows = False if 'enable_slide_windows' not in args else args['enable_slide_windows']
        self.window_size = False if 'window_size' not in args else args['window_size']
        self.word_embedding = args['word_embedding']
        self.grammar_unaries = args['grammar_unaries']
        self.grammar_preterminates = args['grammar_preterminates']
        self.grammar_double_nonterminates = args['grammar_double_nonterminates']
        self.grammar_starts = args['grammar_starts']

        self.merge_model = None # A -> BC
        self.terminate_feature_generation_model = None # D -> E
        self.preterminate_feature_generation_model = None # E -> w
        
        # check parameters
        if (self.enable_slide_windows and self.window_size <= 0) or word == None
            return None
        reset_global_context()
        
    def reset_global_context(self):
        self.records = []
        self.windows_beginning_offset = 0

    # each cell with form (possibility, feature_in_tensor, root_symbol_index)
    def forward(self, word): # input a word at time t.
        # generate the analysis of span [t, t]
        w = self.get_terminate_parses(word)

        # create cells for current time t.
        current_record = [[]] if len(records) == 0 else [[]] * (len(records[-1]) + 1)

        # fill the analysis of span [t, t] to the last cell of current time t.
        current_record[-1] = w

        # Sliding window control the maximum size of a record at time t.
        # When the windows size is large, we may utilize array to store records, rather than list. The array allow us adopting 
        # sliding windows more effectively when the window size is large. However, we currently utilize list to store records.
        # To do sliding window, we need first pop the first record (oldest record inside the window), and for each of rest records,
        # we need pop its first element, as the first element of each record relate to the "word" we'd like to forget.
        if enable_slide_windows:
            while len(self.records) >= self.window_size:
                self.records.pop(0)
                for i in range(0, len(self.records)):
                    self.records[i].pop(0)
                windows_beginning_offset += 1

        # Add record at current time to the global record set.
        records.append(current_record)
    
        for inner_record_index in range(0, len(current_record) - 1):
            current_record_length = len(current_record) 
            current_record[current_record_length - inner_record_index - 2] = []
            cell_data = []
            for split_index in range(current_record_length - inner_record_index - 2, t - windows_beginning_offset):
                # coordination = (index, t)
                cell1_coordination = (current_record_length - inner_record_index - 2, split_index)
                cell2_coordination = (split_index + 1, t - windows_beginning_offset)
                cell_data += self.merge_parses_in_two_cells(records[cell1_coordination[1]][cell1_coordination[0]],\
                                                           records[cell2_coordination[1]][cell2_coordination[0]],\
                                                           cell1_coordination, cell2_coordination)
                current_record[current_record_length - inner_record_index - 2] = beam_search_strategy(cell_data) if beam_search_strategy is not None else cell_data
        return heapq.nlargest(1, records[-1][0])[0][1]
        
        def get_unary_grammar_id(self, w):
            return self.word_unary_grammar_id_mapping[w]

        def get_grammar_for_generation(self, parse_i, parse_j):
            return self.double_nonterminates_grammar_id_mapping["%s#%s" % (str(parse_i[0]), str(parse_j[0]))]
        
        def get_terminate_parses(self, w):
            unary_grammar_id = self.get_unary_grammar_id(w)
            preterminate_grammar_id = self.preterminate_grammar_id(unary_grammar_id)
            feature1 = self.terminate_feature_generation_model(self.word_embedding[w], self.grammar_unaries[unary_grammar_id])
            feature2 = self.preterminate_feature_generation_model(feature1, self.grammar_preterminates[preterminate_grammar_id])
            return [[1.0, feature1, unary_grammar_id], [1.0, feature2, preterminate_grammar_id]]
        
        def merge_parses_in_two_cells(self, cell1, cell2, cell1_id = None, cell2_id = None):
            result = []
            for parse_i in range(cell1):
                for parse_j in range(cell2):
                    p_1 = parse_i[0]
                    p_2 = parse_j[1]
                    g = self.get_grammar(parse_i, parse_j) # [begin_id, left_symbol_id, right_symbol_id, possibility]
                    p = p_1 * p_2 * g[-1]
                    feature = self.merge_model(parse_i[1], parse_j[2], g)
                    result.append([p, feature, g[0]])
            return result
        
        def select_tops(self, cell, beam_size = 5):
            return heapq.nlargest(beam_size, cell)
        
records = CKY_Algorithm(range(0, 10), get_terminate_parses, select_tops, False)

In [139]:
[(0.999, [23124]), (9.45, [2124,33])].max()

AttributeError: 'list' object has no attribute 'max'

In [145]:
heapq.nlargest(1, [[1, [1,2,3]]])[0]

[1, [1, 2, 3]]