In [1]:
import os
import nltk
from collections import defaultdict, Counter
from sklearn import metrics

In [2]:
def evaluate(test_sentences, tagged_test_sentences):
    gold = [str(tag) for sentence in test_sentences for token, tag in sentence]
    pred = [str(tag) for sentence in tagged_test_sentences for token, tag in sentence]
    print(metrics.classification_report(gold, pred))

def get_token_tag_tuples(sent):
    return [nltk.tag.str2tuple(t) for t in sent.split()]

def get_tagged_sentences(text):
    sentences = []
    blocks = text.split("======================================")
    for block in blocks:
        sents = block.split("\n\n")
        for sent in sents:
            sent = sent.replace("\n", "").replace("[", "").replace("]", "")
            if sent != "":
                sentences.append(sent)
    return sentences

In [3]:
def load_treebank_splits(datadir):
    train = []
    dev = []
    test = []
    print("Loading treebank data...")
    for subdir, dirs, files in os.walk(datadir):
        for filename in files:
            if filename.endswith(".pos"):
                filepath = subdir + os.sep + filename
                with open(filepath, "r") as fh:
                    text = fh.read()
                    if int(subdir.split(os.sep)[-1]) in range(0, 19):
                        train += get_tagged_sentences(text)
                    if int(subdir.split(os.sep)[-1]) in range(19, 22):
                        dev += get_tagged_sentences(text)
                    if int(subdir.split(os.sep)[-1]) in range(22, 25):
                        test += get_tagged_sentences(text)
    print("Train set size: ", len(train))
    print("Dev set size: ", len(dev))
    print("Test set size: ", len(test))
    return train, dev, test

In [4]:
def compute_transition_emission_tables(sentences, alpha=1):
    transition_counts = defaultdict(Counter)
    emission_counts = defaultdict(Counter)
    tag_counts = Counter()

    # Include <START> and <STOP> tags during computation
    for sent in sentences:
        tokens_tags = [('<START>', '<START>')] + get_token_tag_tuples(sent) + [('<STOP>', '<STOP>')]
        for i in range(len(tokens_tags) - 1):
            prev_tag, curr_tag = tokens_tags[i][1], tokens_tags[i + 1][1]
            token, tag = tokens_tags[i + 1]
            transition_counts[prev_tag][curr_tag] += 1
            emission_counts[tag][token] += 1
            tag_counts[tag] += 1

    # Add the <START> and <STOP> tags explicitly to tag_counts
    tag_counts['<START>'] += 1
    tag_counts['<STOP>'] += 1

    # Convert counts to probabilities with add-alpha smoothing
    transition_probs = {tag: {next_tag: (count + alpha) / (sum(transition_counts[tag].values()) + alpha * len(tag_counts))
                              for next_tag, count in next_tags.items()}
                        for tag, next_tags in transition_counts.items()}
    
    # Initialize emission probabilities for <START> and <STOP>
    emission_probs = {tag: {word: (count + alpha) / (tag_counts[tag] + alpha * len(emission_counts[tag]))
                            for word, count in words.items()}
                      for tag, words in emission_counts.items()}
    emission_probs['<START>'] = {}
    emission_probs['<STOP>'] = {}

    return transition_probs, emission_probs


In [5]:
def viterbi(sentence, transition_probs, emission_probs, tags):
    n = len(sentence)
    viterbi_table = defaultdict(lambda: defaultdict(float))
    backpointer = defaultdict(lambda: defaultdict(str))
    
    # Initialization
    for tag in tags:
        if tag != '<START>':  # <START> doesn't emit words
            viterbi_table[0][tag] = transition_probs['<START>'].get(tag, 0) * emission_probs[tag].get(sentence[0], 0)
            backpointer[0][tag] = '<START>'

    # Recursion
    for t in range(1, n):
        for tag in tags:
            max_prob, prev_tag = max(
                (viterbi_table[t-1][prev_tag] * transition_probs[prev_tag].get(tag, 0) * emission_probs[tag].get(sentence[t], 0), prev_tag)
                for prev_tag in tags
            )
            viterbi_table[t][tag] = max_prob
            backpointer[t][tag] = prev_tag
    
    # Termination
    max_prob, last_tag = max(
        (viterbi_table[n-1][tag] * transition_probs[tag].get('<STOP>', 0), tag)
        for tag in tags
    )
    
    # Backtrace
    best_path = []
    current_tag = last_tag
    for t in range(n-1, -1, -1):
        best_path.insert(0, current_tag)
        current_tag = backpointer[t][current_tag]
    
    return best_path

In [6]:
def main():
    # Set path for datadir
    datadir = "data\penn-treeban3-wsj\wsj"

    train, dev, test = load_treebank_splits(datadir)

    # Compute transition and emission probabilities
    transition_probs, emission_probs = compute_transition_emission_tables(train)
    tags = list(transition_probs.keys())

    # Evaluate on test set
    test_sentences = [get_token_tag_tuples(sent) for sent in test]
    tagged_test_sentences = []
    for sentence in test_sentences:
        tokens = [token for token, _ in sentence]
        predicted_tags = viterbi(tokens, transition_probs, emission_probs, tags)
        tagged_test_sentences.append(list(zip(tokens, predicted_tags)))
    
    evaluate(test_sentences, tagged_test_sentences)

if __name__ == "__main__":
    main()


Loading treebank data...
Train set size:  51681
Dev set size:  7863
Test set size:  9046


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

           #       1.00      0.27      0.43        22
           $       1.00      0.54      0.70      1138
          ''       1.00      0.65      0.79      1423
           (       1.00      0.51      0.68       249
           )       1.00      0.52      0.69       252
           ,       1.00      0.54      0.70      9056
           .       1.00      0.64      0.78      7035
           :       1.00      0.55      0.71       983
          CC       1.00      0.57      0.72      4289
          CD       1.00      0.56      0.71      6023
          DT       0.99      0.60      0.75     14946
          EX       0.97      0.74      0.84       174
          FW       0.70      0.18      0.29        38
          IN       0.98      0.57      0.72     18147
       IN|RB       0.00      0.00      0.00         0
          JJ       0.90      0.53      0.67     10704
         JJR       0.83      0.58      0.68       581
     JJR|RBR       0.00    