In [None]:
from collections import defaultdict, Counter
import math
import random


# 1. LOAD CORPUS (word/tag word/tag ...)

def load_tagged_corpus(path="wsj_pos_tagged_en.txt"):
    """Read a POS-tagged corpus with one sentence per line.

    Each line is a whitespace-separated sequence of ``word/tag`` tokens.
    The tag is whatever follows the LAST slash, so a word may itself
    contain slashes. Tokens without any slash are dropped, and lines that
    yield no usable token (including blank lines) are skipped.

    Args:
        path: Location of the UTF-8 encoded corpus file.

    Returns:
        A list of sentences, each a list of ``(word, tag)`` tuples.
    """
    corpus = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            # split() with no argument already ignores surrounding whitespace
            # and produces an empty list for blank lines.
            pairs = [
                tuple(item.rsplit("/", 1))
                for item in raw_line.strip().split()
                if "/" in item
            ]
            if pairs:
                corpus.append(pairs)
    return corpus


# 2. K-FOLD SPLIT
def k_fold_split(data, k=5, seed=42):
    """Shuffle ``data`` deterministically and cut it into ``k`` folds.

    The shuffle is performed on a copy, so the caller's sequence is left
    untouched. The first ``k - 1`` folds each hold ``len(data) // k`` items;
    the last fold additionally receives the remainder.

    Args:
        data: Sequence of items (e.g. sentences) to partition.
        k: Number of folds; must be at least 1.
        seed: Seed for the private random generator used for shuffling.

    Returns:
        A list of ``k`` lists that together contain every item exactly once.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    # Work on a copy: shuffling in place would silently reorder the
    # caller's data as a side effect of asking for a split.
    shuffled = list(data)
    random.Random(seed).shuffle(shuffled)

    n = len(shuffled)
    fold_size = n // k
    # Fold i spans bounds[i]:bounds[i + 1]; the final bound is n so the last
    # fold absorbs the n % k leftover items.
    bounds = [i * fold_size for i in range(k)] + [n]
    return [shuffled[bounds[i]:bounds[i + 1]] for i in range(k)]


# 3. HMM TRAINING (EMISSION + TRANSITION)
class HMMTagger:
    """First-order hidden Markov model tagger with additive smoothing.

    ``train`` collects emission counts (tag -> word) and transition counts
    (previous tag -> tag, with ``start_symbol`` standing in for the
    sentence start) and turns them into smoothed probability tables.
    ``viterbi`` then finds the highest-scoring tag sequence for a list of
    words, working in log space.

    Words not seen in training are mapped to the ``unk`` pseudo-word, whose
    probability under each tag comes purely from the smoothing mass.
    """

    # Substitute for zero or missing probabilities so that math.log never
    # receives 0 (which would raise ValueError).
    _MIN_PROB = 1e-12

    def __init__(self, smoothing=1.0):
        """Create an untrained tagger.

        Args:
            smoothing: Additive (Laplace) smoothing constant applied to both
                emission and transition counts.
        """
        self.smoothing = smoothing
        self.tags = set()
        self.words = set()
        self.start_symbol = "<s>"
        self.unk = "<UNK>"

        self.tag_counts = Counter()
        self.emission_counts = defaultdict(Counter)
        self.transition_counts = defaultdict(Counter)

        # Probability tables are (re)built by train(); they exist from the
        # start so the attributes are always defined.
        self.emission_probs = {}
        self.transition_probs = {}

    def train(self, sentences):
        """Accumulate counts from ``sentences`` and rebuild probability tables.

        Args:
            sentences: Iterable of sentences, each an iterable of
                ``(word, tag)`` pairs.
        """
        for sent in sentences:
            prev_tag = self.start_symbol
            self.tag_counts[prev_tag] += 1

            for word, tag in sent:
                self.tags.add(tag)
                self.words.add(word)

                self.tag_counts[tag] += 1
                self.emission_counts[tag][word] += 1
                self.transition_counts[prev_tag][tag] += 1

                prev_tag = tag

        self.words.add(self.unk)

        # Emission table: P(word | tag) with additive smoothing over the
        # vocabulary (which includes the unknown-word symbol).
        vocab_size = len(self.words)
        self.emission_probs = {}
        for tag in self.tags:
            word_counts = self.emission_counts[tag]
            denom = sum(word_counts.values()) + self.smoothing * vocab_size
            self.emission_probs[tag] = {
                w: ((word_counts[w] + self.smoothing) / denom if denom else 0.0)
                for w in self.words
            }

        # Transition table: P(tag | prev). A row is built for the start
        # symbol and for EVERY tag, including tags that were only observed
        # at the end of a sentence and therefore have no outgoing counts;
        # otherwise those rows would miss the smoothed estimate.
        num_tags = len(self.tags)
        self.transition_probs = {}
        for prev in [self.start_symbol, *sorted(self.tags)]:
            next_counts = self.transition_counts.get(prev) or Counter()
            denom = sum(next_counts.values()) + self.smoothing * num_tags
            # denom is 0 only when smoothing is 0 and prev has no outgoing
            # counts; the accessor floors the resulting 0.0.
            self.transition_probs[prev] = {
                tag: ((next_counts[tag] + self.smoothing) / denom if denom else 0.0)
                for tag in self.tags
            }

    def _emission(self, tag, word):
        """Return P(word | tag), mapping unseen words to ``unk``.

        The result is never smaller than ``_MIN_PROB`` so it is always safe
        to take its logarithm.
        """
        if word not in self.words:
            word = self.unk
        return max(self.emission_probs[tag].get(word, 0.0), self._MIN_PROB)

    def _transition(self, prev, tag):
        """Return P(tag | prev), floored at ``_MIN_PROB``."""
        return max(self.transition_probs.get(prev, {}).get(tag, 0.0), self._MIN_PROB)

    # 4. VITERBI DECODING
    def viterbi(self, words):
        """Return the most probable tag sequence for ``words``.

        Args:
            words: Sequence of word strings forming one sentence.

        Returns:
            A list of tags with the same length as ``words``; an empty list
            for an empty sentence.

        Raises:
            ValueError: If the tagger knows no tags (it has not been trained).
        """
        T = len(words)
        if T == 0:
            return []
        if not self.tags:
            raise ValueError("HMMTagger has no tags; call train() before viterbi()")

        # Sorted so that tie-breaking does not depend on set iteration
        # order (which varies between interpreter runs).
        tags = sorted(self.tags)

        # log P(tag | prev) does not depend on the position in the sentence,
        # so compute it once per call instead of once per time step.
        log_trans = {
            tag: [math.log(self._transition(prev, tag)) for prev in tags]
            for tag in tags
        }

        dp = [{} for _ in range(T)]  # dp[t][tag]: best log score ending in tag
        bp = [{} for _ in range(T)]  # bp[t][tag]: predecessor tag on that path

        # Initialization
        for tag in tags:
            dp[0][tag] = (
                math.log(self._transition(self.start_symbol, tag))
                + math.log(self._emission(tag, words[0]))
            )

        # Recursion
        for t in range(1, T):
            w = words[t]
            prev_scores = [dp[t - 1][prev_tag] for prev_tag in tags]
            for tag in tags:
                best_score = -math.inf
                best_prev = None

                for prev_tag, prev_score, step in zip(tags, prev_scores, log_trans[tag]):
                    score = prev_score + step
                    if score > best_score:
                        best_score = score
                        best_prev = prev_tag

                dp[t][tag] = best_score + math.log(self._emission(tag, w))
                bp[t][tag] = best_prev

        # Termination
        final_scores = dp[-1]
        last_tag = max(tags, key=final_scores.__getitem__)

        # Backtrack
        tags_out = [None] * T
        tags_out[-1] = last_tag
        for t in range(T - 1, 0, -1):
            tags_out[t - 1] = bp[t][tags_out[t]]

        return tags_out


# 5. EVALUATION: PRECISION, RECALL, F1

def evaluate(gold_sents, pred_sents):
    """Compute per-tag precision/recall/F1 and the macro-averaged F1.

    A token counts as a true positive for its tag when gold and predicted
    tags agree; otherwise it is a false positive for the predicted tag and
    a false negative for the gold tag. The macro F1 averages over every tag
    that occurs in either the gold or the predicted sequences.

    Args:
        gold_sents: List of gold tag sequences, one per sentence.
        pred_sents: List of predicted tag sequences, aligned with
            ``gold_sents``.

    Returns:
        ``(per_tag, macro_f1)`` where ``per_tag`` maps each tag to a
        ``(precision, recall, f1)`` tuple. For empty input the result is
        ``({}, 0.0)``.

    Raises:
        ValueError: If the number of sentences differs, or a gold sentence
            and its prediction have different lengths.
    """
    if len(gold_sents) != len(pred_sents):
        raise ValueError(
            f"got {len(gold_sents)} gold sentences but {len(pred_sents)} predictions"
        )

    counts = defaultdict(lambda: {"TP": 0, "FP": 0, "FN": 0})

    for idx, (gold, pred) in enumerate(zip(gold_sents, pred_sents)):
        # zip() would silently truncate to the shorter sequence and score
        # misaligned data, so reject length mismatches explicitly.
        if len(gold) != len(pred):
            raise ValueError(
                f"sentence {idx}: {len(gold)} gold tags but {len(pred)} predicted tags"
            )
        for g, p in zip(gold, pred):
            if g == p:
                counts[g]["TP"] += 1
            else:
                counts[p]["FP"] += 1
                counts[g]["FN"] += 1

    # Every tag seen as gold or prediction has an entry in ``counts``.
    per_tag = {}
    for tag in sorted(counts):
        tp = counts[tag]["TP"]
        fp = counts[tag]["FP"]
        fn = counts[tag]["FN"]

        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0

        per_tag[tag] = (prec, rec, f1)

    # No tags at all (empty input): define the macro average as 0.0 instead
    # of dividing by zero.
    if per_tag:
        macro_f1 = sum(f1 for _, _, f1 in per_tag.values()) / len(per_tag)
    else:
        macro_f1 = 0.0
    return per_tag, macro_f1



# 6. RUN K-FOLD
def run_kfold(k=5, path="wsj_pos_tagged_en.txt", smoothing=1.0, seed=42):
    """Run k-fold cross-validation of the HMM tagger and report macro F1.

    For each fold, a fresh tagger is trained on the other folds, the
    held-out fold is decoded with Viterbi, and the macro F1 is printed.
    The average over all folds is printed at the end.

    Args:
        k: Number of folds; must be at least 2 so that every fold has a
            non-empty training set.
        path: Location of the tagged corpus file.
        smoothing: Additive smoothing constant passed to the tagger.
        seed: Seed used when shuffling the corpus into folds.

    Returns:
        The macro F1 averaged across the ``k`` folds.

    Raises:
        ValueError: If ``k`` is smaller than 2.
    """
    if k < 2:
        # With a single fold the training set would be empty and decoding
        # would fail later with an unrelated-looking error.
        raise ValueError(f"k must be at least 2 for cross-validation, got {k!r}")

    data = load_tagged_corpus(path)
    folds = k_fold_split(data, k, seed=seed)

    macro_scores = []

    for i, test in enumerate(folds):
        print(f"\n=== Fold {i+1}/{k} ===")

        # Training data is every sentence outside the held-out fold.
        train = [s for j, fold in enumerate(folds) if j != i for s in fold]

        hmm = HMMTagger(smoothing=smoothing)
        hmm.train(train)

        gold_tags = [[tag for _, tag in sent] for sent in test]
        pred_tags = [hmm.viterbi([word for word, _ in sent]) for sent in test]

        _per_tag, macro_f1 = evaluate(gold_tags, pred_tags)
        print(f"Macro F1 = {macro_f1:.4f}")

        macro_scores.append(macro_f1)

    average = sum(macro_scores) / k
    print("\nAverage Macro F1 Across Folds:", average)
    return average

# Entry point of the notebook cell: run the 5-fold cross-validation
# experiment (reads the corpus from disk and prints per-fold scores).
run_kfold(k=5)


=== Fold 1/5 ===
Macro F1 = 0.2650

=== Fold 2/5 ===
Macro F1 = 0.3468

=== Fold 3/5 ===
Macro F1 = 0.1883

=== Fold 4/5 ===
Macro F1 = 0.2225

=== Fold 5/5 ===
Macro F1 = 0.2626

Average Macro F1 Across Folds: 0.2570221218111244
