In [None]:
import numpy as np
import nltk
from nltk.tokenize import word_tokenize
from nltk.tag import pos_tag
from collections import defaultdict, Counter

# NLTK data packages fetched before anything below runs. 'brown' backs the default
# training corpus; the punkt*/averaged_perceptron_tagger* packages are presumably the
# models loaded by word_tokenize/pos_tag (legacy and newer names both listed, likely to
# cover different NLTK releases -- confirm against the installed NLTK version).
# NOTE(review): 'treebank' and 'universal_tagset' are not referenced by the visible
# code; kept so the set of downloads is unchanged.
NLTK_RESOURCES = (
    'punkt',
    'averaged_perceptron_tagger',
    'brown',
    'treebank',
    'universal_tagset',
    'punkt_tab',
    'averaged_perceptron_tagger_eng',
)
for resource in NLTK_RESOURCES:
    # Called as a statement inside a loop, so the cell no longer echoes a stray `True`.
    nltk.download(resource)

In [None]:
class HMMNextWordPredictor:
    """Predict the next word from part-of-speech statistics gathered over a corpus.

    Training tags each sentence with ``pos_tag`` and builds three normalized tables:

    * ``pos_transitions[tag][next_tag]``  -- P(next tag | current tag)
    * ``pos_emissions[tag][word]``        -- P(lower-cased word | tag)
    * ``pos_sequence_to_next_words[(prev_tag, tag)][word]``
                                          -- P(next lower-cased word | last two tags)

    Prediction first looks up the last two tags of the input; when that yields fewer
    than ``n`` words it fills the remaining slots from the first-order
    transition x emission model. There is no Viterbi/forward decoding, so this is an
    HMM-inspired heuristic rather than a full HMM.
    """

    def __init__(self, corpus=None, min_sentence_length=3):
        """Collect statistics from ``corpus`` and normalize them.

        Args:
            corpus: Iterable of tokenized sentences (each a list of str). ``None``
                (the default) uses the Brown corpus 'news' category.
            min_sentence_length (int): Sentences with fewer tokens are skipped.
        """
        if corpus is None:
            corpus = nltk.corpus.brown.sents(categories=['news'])
        self.corpus = corpus
        self.min_sentence_length = min_sentence_length
        self.pos_transitions = defaultdict(Counter)
        self.pos_emissions = defaultdict(Counter)
        self.pos_sequence_to_next_words = defaultdict(Counter)

        self._train()

    def _train(self):
        """Count tag transitions, emissions and tag-pair -> next-word events, then normalize."""
        print("Training the model...")

        for sentence in self.corpus:
            if len(sentence) < self.min_sentence_length:
                continue

            tagged_sentence = pos_tag(sentence)
            words = [word.lower() for word, _ in tagged_sentence]
            tags = [tag for _, tag in tagged_sentence]

            # Emit every token, including the sentence-final one; otherwise tags that
            # mostly end sentences (e.g. '.') get an empty/skewed emission distribution.
            for word, tag in zip(words, tags):
                self.pos_emissions[tag][word] += 1

            for current_tag, next_tag in zip(tags, tags[1:]):
                self.pos_transitions[current_tag][next_tag] += 1

            # (tag[i-1], tag[i]) -> word[i+1] for every i with both neighbours present.
            for prev_tag, current_tag, next_word in zip(tags, tags[1:], words[2:]):
                self.pos_sequence_to_next_words[(prev_tag, current_tag)][next_word] += 1

        self._normalize_probabilities()

        print("Model training complete!")

    @staticmethod
    def _normalize_table(table):
        """Rescale each inner Counter of ``table`` in place so its values sum to 1."""
        for counts in table.values():
            total = sum(counts.values())
            if not total:  # defensive: never divide an empty Counter
                continue
            for key in counts:
                counts[key] /= total

    def _normalize_probabilities(self):
        """Convert frequency counts to probability distributions."""
        for table in (self.pos_transitions,
                      self.pos_emissions,
                      self.pos_sequence_to_next_words):
            self._normalize_table(table)

    def predict_next_word(self, text, n=5):
        """
        Predict the most likely next words given the input text.

        Args:
            text (str): Input text
            n (int): Number of predictions to return

        Returns:
            list: Up to n ``(word, probability)`` tuples, each word at most once.
            Words conditioned on the last two tags come first; remaining slots are
            filled from the last-tag fallback model. If the text has fewer than two
            tokens, a single ``("Need at least 2 words for prediction", 0)`` entry is
            returned instead. ``n <= 0`` yields an empty list.
        """
        tokens = word_tokenize(text)
        if len(tokens) < 2:
            return [("Need at least 2 words for prediction", 0)]
        if n <= 0:
            return []

        tagged_tokens = pos_tag(tokens)
        pos_sequence = (tagged_tokens[-2][1], tagged_tokens[-1][1])
        last_pos = pos_sequence[1]

        # Primary model: words observed right after this exact tag pair.
        # .get() avoids inserting empty entries into the defaultdict tables.
        candidate_words = self.pos_sequence_to_next_words.get(pos_sequence, Counter())
        predictions = candidate_words.most_common(n)

        if len(predictions) < n:
            # Fallback: P(word | last tag) = sum over next tags of
            # P(next tag | last tag) * P(word | next tag). Summing merges a word reached
            # through several next tags instead of listing it multiple times.
            fallback = Counter()
            for next_pos, trans_prob in self.pos_transitions.get(last_pos, {}).items():
                for word, emit_prob in self.pos_emissions.get(next_pos, {}).items():
                    fallback[word] += trans_prob * emit_prob

            already_predicted = {word for word, _ in predictions}
            # At most len(already_predicted) top entries can be skipped as duplicates.
            for word, prob in fallback.most_common(n + len(already_predicted)):
                if len(predictions) >= n:
                    break
                if word not in already_predicted:
                    predictions.append((word, prob))
                    already_predicted.add(word)

        return predictions

    def interactive_prediction(self):
        """Interactive console for next word prediction.

        Reads lines with ``input()`` until the user types 'exit' (ignoring case and
        surrounding whitespace) or the input stream ends.
        """
        print("Welcome to HMM-based Next Word Prediction!")
        print("Type a sentence and get predictions for the next word.")
        print("Type 'exit' to quit.")

        while True:
            try:
                text = input("\nEnter text: ")
            except EOFError:  # e.g. stdin closed when run non-interactively
                break
            if text.strip().lower() == 'exit':
                break

            predictions = self.predict_next_word(text)

            print("\nPredicted next words:")
            for i, (word, prob) in enumerate(predictions, 1):
                print(f"{i}. {word} (probability: {prob:.4f})")

In [None]:
# Train once on the Brown 'news' sentences; the same instance is reused by the
# interactive cell below.
predictor = HMMNextWordPredictor()

test_text = "The president of the"
predictions = predictor.predict_next_word(test_text)

# One print with an embedded newline emits exactly the same two header lines.
print(f"\nInput: '{test_text}'\nPredicted next words:")
for candidate_word, candidate_prob in predictions:
    print(f"- {candidate_word} (probability: {candidate_prob:.4f})")

In [None]:
# Starts a blocking input() loop; type 'exit' to leave it. Under "Restart & Run All"
# this cell waits for input, so run it manually (or interrupt the kernel) when
# reproducing the notebook.
predictor.interactive_prediction()