In [8]:
import nltk
from sklearn_crfsuite import CRF, metrics
from sklearn.model_selection import KFold
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# Download dataset
nltk.download('brown')
nltk.download('universal_tagset')
from nltk.corpus import brown

# Step 1: Preprocessing
def preprocess_data():
    data = []
    for sentence in brown.tagged_sents(tagset="universal"):
        tokens = [word.lower() for word, tag in sentence]
        tags = [tag for word, tag in sentence]
        data.append((tokens, tags))
    return data

def extract_features(sentence, i):
    word = sentence[i]
    features = {
        'word': word,
        'is_first': i == 0,
        'is_last': i == len(sentence) - 1,
        'is_capitalized': word[0].upper() == word[0],
        'is_digit': word.isdigit(),
        'prev_word': '' if i == 0 else sentence[i - 1],
        'next_word': '' if i == len(sentence) - 1 else sentence[i + 1],
    }
    return features

def prepare_dataset(data):
    X, y = [], []
    for tokens, tags in data:
        X.append([extract_features(tokens, i) for i in range(len(tokens))])
        y.append(tags)
    return X, y

# Step 2: Model Training and k-Fold Cross-Validation
def train_crf(X, y, k=5):
    kf = KFold(n_splits=k)
    precision, recall, f1_scores = [], [], []
    model = None

    for train_index, test_index in kf.split(X):
        X_train, X_test = [X[i] for i in train_index], [X[i] for i in test_index]
        y_train, y_test = [y[i] for i in train_index], [y[i] for i in test_index]

        model = CRF(algorithm='lbfgs', max_iterations=100, all_possible_transitions=True)
        model.fit(X_train, y_train)

        y_pred = model.predict(X_test)

        report = metrics.flat_classification_report(y_test, y_pred, output_dict=True, digits=3)

        # Overall metrics
        precision.append(report["macro avg"]["precision"])
        recall.append(report["macro avg"]["recall"])
        f1_scores.append(report["macro avg"]["f1-score"])

    return model, precision, recall, f1_scores

# Step 3: Evaluation and Visualization
def generate_confusion_matrix(y_test, y_pred, labels):
    # Flatten the lists of true and predicted labels
    y_true_flat = [item for sublist in y_test for item in sublist]
    y_pred_flat = [item for sublist in y_pred for item in sublist]

    cm = confusion_matrix(y_true_flat, y_pred_flat, labels=labels)

    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt="d", xticklabels=labels, yticklabels=labels, cmap="Blues")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()

# Main Execution
data = preprocess_data()
X, y = prepare_dataset(data)

# Perform k-fold cross-validation and train the model
model, precision, recall, f1_scores = train_crf(X, y)

# Overall Metrics
print(f"Precision: {np.mean(precision):.3f}")
print(f"Recall: {np.mean(recall):.3f}")
print(f"F1-Score: {np.mean(f1_scores):.3f}")

# Interactive POS Tagging
print("\nModel trained. You can now input sentences for POS tagging.")
while True:
    sentence = input("\nEnter a sentence for POS tagging (or type 'exit' to quit): ")
    if sentence.lower() == 'exit':
        break
    tokens = sentence.lower().split()
    features = [extract_features(tokens, i) for i in range(len(tokens))]
    predicted_tags = model.predict_single(features)
    print("\nPOS Tags:")
    for token, tag in zip(tokens, predicted_tags):
        print(f"{token}: {tag}")


[nltk_data] Downloading package brown to /root/nltk_data...
[nltk_data]   Package brown is already up-to-date!
[nltk_data] Downloading package universal_tagset to /root/nltk_data...
[nltk_data]   Package universal_tagset is already up-to-date!


Precision: 0.909
Recall: 0.886
F1-Score: 0.894

Model trained. You can now input sentences for POS tagging.

Enter a sentence for POS tagging (or type 'exit' to quit): I am Nitesh Singh

POS Tags:
i: PRON
am: VERB
nitesh: ADJ
singh: NOUN

Enter a sentence for POS tagging (or type 'exit' to quit): exit
