In [1]:
import os
import math
import nltk
from collections import defaultdict, Counter
from sklearn import metrics

def evaluate(test_sentences, tagged_test_sentences):
    gold = [str(tag) for sentence in test_sentences for token, tag in sentence]
    pred = [str(tag) for sentence in tagged_test_sentences for token, tag in sentence]
    print(metrics.classification_report(gold, pred))

def get_token_tag_tuples(sent):
    return [nltk.tag.str2tuple(t) for t in sent.split()]

def get_tagged_sentences(text):
    sentences = []
    blocks = text.split("======================================")
    for block in blocks:
        sents = block.split("\n\n")
        for sent in sents:
            sent = sent.replace("\n", "").replace("[", "").replace("]", "")
            if sent != "":
                sentences.append(sent)
    return sentences

def load_treebank_splits(datadir):
    train = []
    dev = []
    test = []
    print("Loading treebank data...")
    for subdir, dirs, files in os.walk(datadir):
        for filename in files:
            if filename.endswith(".pos"):
                filepath = subdir + os.sep + filename
                with open(filepath, "r") as fh:
                    text = fh.read()
                    if int(subdir.split(os.sep)[-1]) in range(0, 19):
                        train += get_tagged_sentences(text)
                    if int(subdir.split(os.sep)[-1]) in range(19, 22):
                        dev += get_tagged_sentences(text)
                    if int(subdir.split(os.sep)[-1]) in range(22, 25):
                        test += get_tagged_sentences(text)
    print("Train set size: ", len(train))
    print("Dev set size: ", len(dev))
    print("Test set size: ", len(test))
    return train, dev, test

def compute_transition_emission_tables(sentences, alpha=1):
    transition_counts = defaultdict(Counter)
    emission_counts = defaultdict(Counter)
    tag_counts = Counter()

    for sent in sentences:
        tokens_tags = [('<START>', '<START>')] + get_token_tag_tuples(sent) + [('<STOP>', '<STOP>')]
        for i in range(len(tokens_tags) - 1):
            prev_tag, curr_tag = tokens_tags[i][1], tokens_tags[i + 1][1]
            token, tag = tokens_tags[i + 1]
            transition_counts[prev_tag][curr_tag] += 1
            if tag != '<START>' and tag != '<STOP>':  # <START> and <STOP> don't emit tokens
                emission_counts[tag][token] += 1
            tag_counts[tag] += 1

    # Add <START> and <STOP> explicitly
    tag_counts['<START>'] += 1
    tag_counts['<STOP>'] += 1

    # Convert counts to log probabilities
    transition_probs = {tag: {next_tag: math.log((count + alpha) / (sum(transition_counts[tag].values()) + alpha * len(tag_counts)))
                              for next_tag, count in next_tags.items()}
                        for tag, next_tags in transition_counts.items()}

    # Ensure <START> and <STOP> have no emissions
    emission_probs = {tag: {word: math.log((count + alpha) / (tag_counts[tag] + alpha * len(emission_counts[tag])))
                            for word, count in words.items()}
                      for tag, words in emission_counts.items()}
    emission_probs['<START>'] = {}
    emission_probs['<STOP>'] = {}

    return transition_probs, emission_probs


def viterbi(sentence, transition_probs, emission_probs, tags):
    n = len(sentence)
    viterbi_table = defaultdict(lambda: defaultdict(float))
    backpointer = defaultdict(lambda: defaultdict(str))

    # Initialization
    for tag in tags:
        if tag != '<START>':  # <START> doesn't emit tokens
            viterbi_table[0][tag] = transition_probs['<START>'].get(tag, float('-inf')) + emission_probs[tag].get(sentence[0], float('-inf'))
            backpointer[0][tag] = '<START>'

    # Recursion
    for t in range(1, n):
        for tag in tags:
            if tag != '<START>' and tag != '<STOP>':  # Skip invalid transitions
                max_prob, best_prev_tag = max(
                    (viterbi_table[t-1][prev_tag] + transition_probs[prev_tag].get(tag, float('-inf')) + emission_probs[tag].get(sentence[t], float('-inf')),
                     prev_tag) for prev_tag in tags if prev_tag != '<STOP>')
                viterbi_table[t][tag] = max_prob
                backpointer[t][tag] = best_prev_tag

    # Termination
    max_prob, last_tag = max(
        (viterbi_table[n-1][tag] + transition_probs[tag].get('<STOP>', float('-inf')), tag)
        for tag in tags if tag != '<START>'
    )

    # Backtrace
    best_path = []
    current_tag = last_tag
    for t in range(n-1, -1, -1):
        best_path.insert(0, current_tag)
        current_tag = backpointer[t][current_tag]

    return best_path


def baseline_tagger(train_sentences, test_sentences):
    # Build a most frequent tagger
    tag_count = defaultdict(Counter)
    for sent in train_sentences:
        tokens_tags = get_token_tag_tuples(sent)
        for token, tag in tokens_tags:
            tag_count[token][tag] += 1

    most_frequent_tags = {token: tags.most_common(1)[0][0] for token, tags in tag_count.items()}

    tagged_test_sentences = []
    for sent in test_sentences:
        tokens = [token for token, _ in sent]
        predicted_tags = [most_frequent_tags.get(token, 'NN') for token in tokens]
        tagged_test_sentences.append(list(zip(tokens, predicted_tags)))

    return tagged_test_sentences

def main():
    # Set path for datadir
    datadir = "data\penn-treeban3-wsj\wsj"

    train, dev, test = load_treebank_splits(datadir)

    # Compute transition and emission probabilities
    transition_probs, emission_probs = compute_transition_emission_tables(train)
    tags = list(transition_probs.keys())

    # Evaluate on test set with Viterbi decoding
    test_sentences = [get_token_tag_tuples(sent) for sent in test]
    tagged_test_sentences = []
    for sentence in test_sentences:
        tokens = [token for token, _ in sentence]
        predicted_tags = viterbi(tokens, transition_probs, emission_probs, tags)
        tagged_test_sentences.append(list(zip(tokens, predicted_tags)))
    
    print("Viterbi Decoding Evaluation:")
    evaluate(test_sentences, tagged_test_sentences)

    # Baseline tagger evaluation
    print("Baseline Tagger Evaluation:")
    baseline_predictions = baseline_tagger(train, test_sentences)
    evaluate(test_sentences, baseline_predictions)

if __name__ == "__main__":
    main()


Loading treebank data...
Train set size:  51681
Dev set size:  7863
Test set size:  9046
Viterbi Decoding Evaluation:


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

                   0.00      0.00      0.00         0
           #       0.00      0.00      0.00        22
           $       0.00      0.00      0.00      1138
          ''       1.00      0.23      0.38      1423
           (       1.00      0.00      0.01       249
           )       1.00      0.22      0.36       252
           ,       1.00      0.00      0.00      9056
           .       1.00      1.00      1.00      7035
           :       1.00      0.04      0.08       983
     <START>       0.00      0.00      0.00         0
          CC       0.00      0.00      0.00      4289
          CD       0.99      0.03      0.05      6023
          DT       0.86      0.01      0.02     14946
          EX       1.00      0.02      0.03       174
          FW       0.50      0.03      0.05        38
          IN       1.00      0.00      0.00     18147
          JJ       0.88      0.00      0.00     10704
         JJR       0.00    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

           #       1.00      1.00      1.00        22
           $       1.00      1.00      1.00      1138
          ''       1.00      0.99      1.00      1423
           (       1.00      1.00      1.00       249
           )       1.00      1.00      1.00       252
           ,       1.00      1.00      1.00      9056
           .       1.00      1.00      1.00      7035
           :       1.00      1.00      1.00       983
          CC       1.00      1.00      1.00      4289
          CD       0.99      0.90      0.94      6023
          DT       0.99      0.99      0.99     14946
          EX       0.89      1.00      0.94       174
          FW       0.38      0.24      0.29        38
          IN       0.94      0.98      0.96     18147
          JJ       0.88      0.86      0.87     10704
         JJR       0.66      0.95      0.78       581
     JJR|RBR       0.00      0.00      0.00         4
         JJS       0.79    