<a href="https://colab.research.google.com/github/Nahom32/CKY-Parser/blob/main/Copy_of_CKY_parser.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install nltk



In [None]:
!pip install spacy



In [None]:
from collections import defaultdict


class CKYParser:
    """
    Probabilistic CKY parser that keeps the ``k`` best sub-parses for every
    non-terminal in every chart cell.

    ``grammar`` must expose the three rule tables built by ``PCFG``:
    ``lexical[word]`` and ``unary[B]`` map to ``[(A, logp), ...]`` and
    ``binary[(B, C)]`` maps to ``[(A, logp), ...]``.

    Trees are nested tuples: ``(A, word)`` for a pre-terminal, ``(A, child)``
    for a unary node and ``(A, left, right)`` for a binary node.
    """

    def __init__(self, grammar, k=3, start_symbol="S"):
        self.grammar = grammar
        self.k = k  # beam size: parses kept per non-terminal per cell
        self.start = start_symbol

    def parse(self, sentence: str):
        """
        Parse a whitespace-tokenised sentence.

        A token is looked up verbatim first and then lower-cased (quoted
        terminals are lower-cased when a grammar is loaded from a file);
        the original token is kept as the tree leaf.

        Returns:
            Up to ``k`` ``(logp, tree)`` pairs rooted in the start symbol,
            best first; ``[]`` if the sentence is empty or has no such parse.

        Raises:
            ValueError: if a token has no lexical rule.
        """
        words = sentence.split()
        n = len(words)
        if n == 0:
            # No chart cell chart[0][0] exists; nothing can be parsed.
            return []

        # chart[i][j]: non-terminal -> k-best (logp, tree) covering words[i:j].
        chart = [[defaultdict(list) for _ in range(n + 1)] for _ in range(n)]
        lexical = self.grammar.lexical

        for i, word in enumerate(words):
            if word in lexical:
                key = word
            elif word.lower() in lexical:
                key = word.lower()
            else:
                raise ValueError(f"No lexical rule for word: '{word}'")

            for A, logp in lexical[key]:
                chart[i][i + 1][A].append((logp, (A, word)))

            self._apply_unary_closure(chart, i, i + 1)

        binary = self.grammar.binary
        for span in range(2, n + 1):
            for i in range(n - span + 1):
                j = i + span
                cell = chart[i][j]

                # Try every split point: words[i:mid] + words[mid:j].
                for mid in range(i + 1, j):
                    left_cell = chart[i][mid]
                    right_cell = chart[mid][j]

                    for B in left_cell:
                        for C in right_cell:
                            if (B, C) not in binary:
                                continue
                            for A, rule_logp in binary[(B, C)]:
                                for p1, t1 in left_cell[B]:
                                    for p2, t2 in right_cell[C]:
                                        total_logp = rule_logp + p1 + p2
                                        cell[A].append((total_logp, (A, t1, t2)))

                self._prune(cell)
                self._apply_unary_closure(chart, i, j)

        return chart[0][n].get(self.start, [])

    def _prune(self, cell):
        """Keep only the ``k`` highest-scoring entries for each non-terminal."""
        for nt in cell:
            cell[nt] = sorted(cell[nt], key=lambda x: -x[0])[: self.k]

    def _apply_unary_closure(self, chart, i, j):
        """
        Apply unary rules ``A → B`` in cell ``chart[i][j]`` until the pruned
        k-best lists stop changing.

        Convergence is judged on the cell contents *after* pruning. A
        derivation that is generated and immediately pruned away therefore
        does not keep the loop alive. Unary cycles (including ``A → A``)
        stop once every list is full of entries at least as good as any new
        derivation.
        """
        cell = chart[i][j]
        unary = self.grammar.unary
        while True:
            before = {nt: list(entries) for nt, entries in cell.items()}

            for B in list(cell):
                if B not in unary:
                    continue
                for A, rule_logp in unary[B]:
                    # Iterate a snapshot: when A == B we append to the list being read.
                    for p, subtree in list(cell[B]):
                        new_tree = (A, subtree)
                        if not self._exists(cell[A], new_tree):
                            cell[A].append((rule_logp + p, new_tree))

            self._prune(cell)
            if cell == before:
                return

    def _exists(self, entries, tree):
        """Return True if ``tree`` already appears in the ``(logp, tree)`` list."""
        for _, t in entries:
            if t == tree:
                return True
        return False


In [None]:
from math import log
import re
class PCFG:
    """
    A probabilistic context-free grammar in Chomsky normal form.

    Every rule ``lhs → rhs`` has a probability ``0 < p <= 1``. The natural
    logarithm of that probability is stored, so derivation scores can be
    summed instead of multiplied. Rules are indexed by their right-hand side
    in three tables:

    * ``lexical[word]``  -> ``[(lhs, logp), ...]`` for ``lhs → word``
    * ``unary[B]``       -> ``[(lhs, logp), ...]`` for ``lhs → B``
    * ``binary[(B, C)]`` -> ``[(lhs, logp), ...]`` for ``lhs → B C``
    """

    def __init__(self):
        self.binary = defaultdict(list)
        self.unary = defaultdict(list)
        self.lexical = defaultdict(list)

    def add_rule(self, lhs, rhs, prob, terminal=None):
        """
        Register one CNF rule.

        Args:
            lhs: Left-hand-side non-terminal.
            rhs: Sequence of one or two right-hand-side symbols.
            prob: Rule probability; must satisfy ``0 < prob <= 1``.
            terminal: For single-symbol rules, whether ``rhs[0]`` is a word.
                ``None`` (default) keeps the heuristic that a symbol is a
                terminal iff ``str.islower()`` is true. That heuristic
                classifies punctuation or digits as non-terminals, so pass
                ``terminal=True`` for those.

        Raises:
            ValueError: if ``prob`` is outside ``(0, 1]`` or ``rhs`` does not
                have exactly one or two symbols.
        """
        if not 0 < prob <= 1:
            raise ValueError(
                f"Rule probability must be in (0, 1], got {prob}: {lhs} → {rhs}"
            )
        logp = log(prob)

        if len(rhs) == 1:
            is_terminal = rhs[0].islower() if terminal is None else terminal
            table = self.lexical if is_terminal else self.unary
            table[rhs[0]].append((lhs, logp))

        elif len(rhs) == 2:
            self.binary[(rhs[0], rhs[1])].append((lhs, logp))

        else:
            raise ValueError(f"Non-CNF rule passed to PCFG: {lhs} → {rhs}")


# One grammar rule per line: ``LHS → SYM1 SYM2 ... [prob]``.
# Both the Unicode arrow ``→`` and the ASCII ``->`` are accepted.
RULE_PATTERN = re.compile(
    r"""
    ^\s*
    (?P<lhs>[A-Za-z_]+)          # left-hand-side non-terminal
    \s*(?:→|->)\s*               # rule arrow
    (?P<rhs>.+?)                 # right-hand-side symbols, whitespace separated
    \s*\[(?P<prob>[0-9.]+)\]     # probability in square brackets
    \s*$
    """,
    re.VERBOSE,
)


def binarize_rule(lhs, rhs, prob):
    """
    Convert ``lhs → X1 X2 ... Xn`` (n > 2) into a right-branching chain of
    binary rules. Rules with at most two symbols are returned unchanged.

    The original probability is placed on the first rule and every
    introduced rule gets probability 1.0, so the product along the chain
    equals ``prob``.

    Intermediate non-terminals are named after the whole original rule
    (``{lhs}_BIN{i}<X1|X2|...|Xn>``). Naming them after ``lhs`` alone would
    let two different long rules with the same left-hand side share
    intermediates. That sharing creates spurious derivations and
    double-counts probability mass.

    Returns:
        A list of ``(lhs, rhs, prob)`` triples.
    """
    if len(rhs) <= 2:
        return [(lhs, rhs, prob)]

    rule_tag = "|".join(rhs)
    rules = []
    current_lhs = lhs
    current_prob = prob

    for i in range(len(rhs) - 2):
        new_nt = f"{lhs}_BIN{i}<{rule_tag}>"
        rules.append((current_lhs, [rhs[i], new_nt], current_prob))
        current_lhs = new_nt
        current_prob = 1.0

    rules.append((current_lhs, rhs[-2:], 1.0))
    return rules


def normalize_symbol(sym):
    """
    Strip surrounding whitespace and, for a quoted terminal, remove the
    quotes and lower-case it.

    Both double (``"the"``) and single (``'the'``) quotes are accepted. A
    symbol counts as quoted only when it is at least two characters long and
    starts and ends with the same quote character, so a lone ``"`` is not
    turned into an empty terminal. Unquoted symbols (non-terminals) are
    returned unchanged apart from stripping.
    """
    sym = sym.strip()
    if len(sym) >= 2 and sym[0] == sym[-1] and sym[0] in "\"'":
        return sym[1:-1].lower()
    return sym


def load_pcfg_from_file(path: str) -> PCFG:
    """
    Load a PCFG from a text file containing one rule per line.

    Each rule line must match ``RULE_PATTERN``
    (``LHS → SYM1 SYM2 ... [prob]``). Blank lines and lines starting with
    ``#`` are ignored. Right-hand-side symbols go through
    ``normalize_symbol``, and each rule goes through ``binarize_rule``
    before being added to the grammar. A leading UTF-8 byte-order mark is
    tolerated.

    Note: terminals inside multi-symbol rules are not replaced by
    pre-terminals. Such rules are stored as binary rules over the terminal
    string itself.

    Raises:
        SyntaxError: if a line does not match the rule syntax.
        ValueError: if the probability is not a valid number or the grammar
            rejects the rule; the message includes the line number.
    """
    grammar = PCFG()

    # utf-8-sig behaves like utf-8 but also drops a leading BOM, which would
    # otherwise make the first rule fail to match.
    with open(path, encoding="utf-8-sig") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()

            if not line or line.startswith("#"):
                continue

            m = RULE_PATTERN.match(line)
            if not m:
                raise SyntaxError(f"Invalid rule at line {line_no}: {line}")

            lhs = m.group("lhs")
            rhs = [normalize_symbol(s) for s in m.group("rhs").split()]

            raw_prob = m.group("prob")
            try:
                prob = float(raw_prob)
            except ValueError as err:
                raise ValueError(
                    f"Invalid probability at line {line_no}: {raw_prob!r}"
                ) from err

            try:
                for A, B, p in binarize_rule(lhs, rhs, prob):
                    grammar.add_rule(A, B, p)
            except ValueError as err:
                raise ValueError(
                    f"Cannot add rule at line {line_no} ({line}): {err}"
                ) from err

    return grammar



In [None]:
from nltk import Tree


def pretty_print_tree(tree, indent=0):
    """
    Print a CKY tuple tree in indented, bracketed form.

    Pre-terminals ``(A, "word")`` print on one line as ``(A word)``. Any other
    node prints ``(A`` on its own line, then its children indented two spaces
    deeper, then a closing ``)`` at the node's own indentation.
    """
    pad = "  " * indent

    if isinstance(tree, tuple) and len(tree) == 2 and isinstance(tree[1], str):
        print(f"{pad}({tree[0]} {tree[1]})")
        return

    if len(tree) == 2:
        label, children = tree[0], (tree[1],)
    else:
        label, left, right = tree
        children = (left, right)

    print(f"{pad}({label}")
    for child in children:
        pretty_print_tree(child, indent + 1)
    print(f"{pad})")


def to_nltk_tree(tree):
    """
    Convert a CKY tuple tree into an ``nltk.Tree``.

    ``(A, "word")`` becomes a pre-terminal with a string leaf, ``(A, child)``
    a unary node and ``(A, left, right)`` a binary node. Children are
    converted recursively.
    """
    is_pair = len(tree) == 2
    if is_pair and isinstance(tree[1], str):
        return Tree(tree[0], [tree[1]])

    if is_pair:
        label, children = tree[0], [tree[1]]
    else:
        label, left, right = tree
        children = [left, right]
    return Tree(label, [to_nltk_tree(child) for child in children])


In [None]:
from google.colab import drive
# Mount Google Drive so the grammar file under /content/drive/MyDrive can be read.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:

# Sentences used to exercise the parser. A word without a lexical rule makes
# CKYParser.parse raise ValueError; that case is reported per sentence below.
test_sentences = [
    "the man saw the dog",
    "the dog chased the cat",
    "the man saw the dog with the telescope",
    "the dog saw the man in the park",
    "the big dog saw the small cat",
    "the man and the woman saw the dog",
    "the man quickly saw the dog",
    "the woman saw the man with the telescope",
    "the dog walked in the park",
    "the man liked the dog",
]

GRAMMAR_PATH = "/content/drive/MyDrive/rules.txt"  # grammar file on the mounted Drive

# sentence -> list of (logp, tree) parses kept by the parser, best first.
results = {}
grammar = load_pcfg_from_file(GRAMMAR_PATH)
parser = CKYParser(grammar)
for sent in test_sentences:
    print(f"\nSentence: {sent}")
    try:
        parses = parser.parse(sent)
    except ValueError as err:
        # One out-of-vocabulary word should not abort the whole batch.
        print(f"Could not parse: {err}")
        results[sent] = []
        continue

    results[sent] = parses
    # The parser keeps at most k parses, so this is not the total ambiguity.
    print(f"Number of parses: {len(parses)}")

    if parses:
        logp, tree = parses[0]
        print(f"Best log-probability: {logp:.4f}")
        pretty_print_tree(tree)



Sentence: the man saw the dog
Number of parses: 1
(S
  (NP
    (Det the)
    (N man)
  )
  (VP
    (V saw)
    (NP
      (Det the)
      (N dog)
    )
  )
)

Sentence: the dog chased the cat
Number of parses: 1
(S
  (NP
    (Det the)
    (N dog)
  )
  (VP
    (V chased)
    (NP
      (Det the)
      (N cat)
    )
  )
)

Sentence: the man saw the dog with the telescope
Number of parses: 2
(S
  (NP
    (Det the)
    (N man)
  )
  (VP
    (V saw)
    (NP
      (NP
        (Det the)
        (N dog)
      )
      (PP
        (P with)
        (NP
          (Det the)
          (N telescope)
        )
      )
    )
  )
)

Sentence: the dog saw the man in the park
Number of parses: 2
(S
  (NP
    (Det the)
    (N dog)
  )
  (VP
    (V saw)
    (NP
      (NP
        (Det the)
        (N man)
      )
      (PP
        (P in)
        (NP
          (Det the)
          (N park)
        )
      )
    )
  )
)

Sentence: the big dog saw the small cat
Number of parses: 1
(S
  (NP
    (Det the)
    (NP_

In [None]:
from nltk import Tree

def cky_to_nltk(tree):
    """
    Build an NLTK ``Tree`` from a CKY parse tuple.

    Handles the three tuple shapes the CKY parser produces: lexical
    ``(A, "word")``, unary ``(A, child)`` and binary ``(A, left, right)``.
    """
    if isinstance(tree, tuple) and len(tree) == 2:
        label, only = tree
        if isinstance(only, str):
            # Pre-terminal: keep the word as a plain string leaf.
            return Tree(label, [only])
        return Tree(label, [cky_to_nltk(only)])

    label, left, right = tree
    return Tree(label, [cky_to_nltk(left), cky_to_nltk(right)])


In [None]:
parser = CKYParser(grammar)
for sent in test_sentences:
    print(f"\nSentence: {sent}")
    try:
        parses = parser.parse(sent)
    except ValueError as err:
        # One out-of-vocabulary word should not abort the whole batch.
        print(f"Could not parse: {err}")
        results[sent] = []
        continue

    results[sent] = parses
    print(f"Number of parses: {len(parses)}")

    if parses:
        _, tree = parses[0]
        print("NLTK Tree:")
        cky_to_nltk(tree).pretty_print()



Sentence: the man saw the dog
Number of parses: 1
NLTK Tree:
             S             
      _______|___           
     |           VP        
     |        ___|___       
     NP      |       NP    
  ___|___    |    ___|___   
Det      N   V  Det      N 
 |       |   |   |       |  
the     man saw the     dog


Sentence: the dog chased the cat
Number of parses: 1
NLTK Tree:
              S               
      ________|_____           
     |              VP        
     |         _____|___       
     NP       |         NP    
  ___|___     |      ___|___   
Det      N    V    Det      N 
 |       |    |     |       |  
the     dog chased the     cat


Sentence: the man saw the dog with the telescope
Number of parses: 2
NLTK Tree:
                 S                                
      ___________|_______                          
     |                   VP                       
     |        ___________|___                      
     |       |               NP              

In [None]:
import nltk
from nltk import pos_tag, word_tokenize
from nltk.parse import CoreNLPParser
def nltk_parse(sentence):
    """
    Tokenise ``sentence`` with NLTK and return its part-of-speech tags.

    Despite the name, no parse tree is built: the result is whatever
    ``nltk.pos_tag`` returns for the ``word_tokenize`` tokens.

    NOTE(review): presumably requires the NLTK tokenizer/tagger data to be
    downloaded first; no ``nltk.download`` call is visible in this notebook.
    """
    return pos_tag(word_tokenize(sentence))

In [None]:
# Hand-written reference label sets: for each sentence, the constituent labels
# expected somewhere in its correct parse. Only labels are recorded, not spans,
# so the evaluation below measures label recall rather than bracket accuracy.
gold = {
    "the man saw the dog": {"S", "NP", "VP"},
    "the man saw the dog with the telescope": {"S", "NP", "VP", "PP"}
}
def extract_constituents(tree, spans=None):
    """
    Collect the set of node labels occurring in a CKY tuple tree.

    Despite the parameter name, no spans are recorded. Only labels are
    collected: ``node[0]`` of every tuple node, including pre-terminals.
    If ``spans`` is given, labels are added to that set in place and the
    same set is returned; otherwise a new set is created.
    """
    labels = set() if spans is None else spans
    stack = [tree]
    while stack:
        node = stack.pop()
        if not isinstance(node, tuple):
            continue
        labels.add(node[0])
        # Push children right-to-left so they are visited left-to-right (pre-order).
        stack.extend(reversed([c for c in node[1:] if isinstance(c, tuple)]))
    return labels


In [None]:
# Label recall against `gold`: the fraction of gold labels that occur anywhere
# in the best parse. Sentences that fail to parse still count toward the
# denominator, so failures lower the score instead of being silently skipped.
correct = 0
total = 0

for sent, gold_const in gold.items():
    total += len(gold_const)
    try:
        parses = parser.parse(sent)
    except ValueError:
        parses = []
    if not parses:
        continue

    _, best_tree = parses[0]
    predicted = extract_constituents(best_tree)
    correct += len(predicted & gold_const)

accuracy = correct / total if total else 0.0
print(f"Parsing accuracy: {accuracy:.2f}")


Parsing accuracy: 1.00


In [None]:
import time
import random

def random_sentence(length, vocab):
    """Return ``length`` words drawn with ``random.choice`` from ``vocab``, joined by spaces."""
    words = [random.choice(vocab) for _ in range(length)]
    return " ".join(words)

# Rough timing of CKYParser.parse on random word sequences of increasing length
# (one run per length, so expect noise).
vocab = list(grammar.lexical.keys())
lengths = [5, 10, 15]

for n in lengths:
    sent = random_sentence(n, vocab)
    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # and can jump or be too coarse for sub-millisecond intervals.
    start = time.perf_counter()
    parser.parse(sent)
    elapsed = time.perf_counter() - start
    print(f"Length {n}: {elapsed:.4f} seconds")


Length 5: 0.0002 seconds
Length 10: 0.0003 seconds
Length 15: 0.0005 seconds
