In [67]:
import json
from nltk import ngrams
import nltk
from collections import Counter
from sklearn import tree
from sklearn.model_selection import cross_val_score

In [45]:
def extract_headline_category(category_list, path='./News_Category_Dataset.json'):
    """Load ``(headline, CATEGORY)`` pairs for the requested categories.

    The dataset is read as JSON Lines: one JSON object per line, each with
    at least ``headline`` and ``category`` keys. Blank lines are skipped.

    Parameters
    ----------
    category_list : iterable of str
        Categories to keep. Matching is case-insensitive.
    path : str, optional
        Location of the JSON Lines dataset; defaults to the file this
        notebook was written against.

    Returns
    -------
    list of (str, str)
        ``(headline, category)`` tuples in file order. ``category`` is
        upper-cased, so it lines up with an upper-cased ``category_list``
        that is later used for ``category_list.index(category)`` label
        encoding.

    Raises
    ------
    OSError
        If the file cannot be opened or read. An error message is printed
        first, then the original exception is re-raised (previously the
        notebook called ``exit(0)``, which signalled success on failure).
    """
    wanted = {name.upper() for name in category_list}
    headline_category = []
    try:
        # Explicit UTF-8: headlines contain non-ASCII characters (e.g. curly quotes),
        # and the platform default encoding is not guaranteed to be UTF-8.
        with open(path, encoding='utf-8') as input_file:
            for line in input_file:
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                data = json.loads(line)
                category = data['category'].upper()
                if category in wanted:
                    # Store the normalised label so it matches the filter key.
                    headline_category.append((data['headline'], category))
    except IOError:
        print("ERROR : IO ERROR occurred while opening file")
        raise
    return headline_category


In [46]:
# Categories to keep, upper-cased because extract_headline_category compares
# against the upper-cased 'category' field of each record.
category_list = [name.upper() for name in ('Business', 'Comedy', 'Sports', 'Crime', 'Religion')]
headlines_and_category = extract_headline_category(category_list)

In [47]:
# Peek at the first twenty (headline, category) pairs.
print(headlines_and_category[0:20])

[('There Were 2 Mass Shootings In Texas Last Week, But Only 1 On TV', 'CRIME'), ('Rachel Dolezal Faces Felony Charges For Welfare Fraud', 'CRIME'), ("Trump's New 'MAGA'-Themed Swimwear Sinks On Twitter", 'COMEDY'), ('Seth Meyers Has 1 Funny Regret After Trump Cancels North Korea Summit', 'COMEDY'), ('Colbert Wants To Turn NYC Subway Rides Into A New And Terrible Punishment', 'COMEDY'), ("Man Faces Charges After Pulling Knife, Stun Gun On Muslim Students At McDonald's", 'CRIME'), ("Jimmy Kimmel Knows Why Iran's Supreme Leader Watches 'Tom And Jerry'", 'COMEDY'), ('2 People Injured In Indiana School Shooting', 'CRIME'), ("'Late Night' Writer's Breathless Royal Wedding Recap Is The Only One You Need", 'COMEDY'), ("Jets Chairman Christopher Johnson Won't Fine Players For Anthem Protests", 'SPORTS'), ('Colbert Exposes The\xa0Biggest Flaw In Trump’s Latest Conspiracy Theory', 'COMEDY'), ('Seth Meyers Gives Donald Trump Some Valuable Marketing Advice', 'COMEDY'), ('U.S. Launches Auto Import P

In [59]:
# Number of headlines that fall into the selected categories.
n_headlines = len(headlines_and_category)
print(n_headlines)

17841


In [48]:
def get_n_grams(dataset, n: int, k, verbose: bool = True):
    """Return the ``k`` most frequent word n-grams across all headlines.

    Parameters
    ----------
    dataset : iterable of (headline, category) pairs
        Only the headline text is tokenised; the category is ignored.
    n : int
        N-gram order (1 = unigrams, 2 = bigrams, ...).
    k : int
        How many n-grams to return (``None`` returns every distinct
        n-gram, as with ``Counter.most_common``).
    verbose : bool, optional
        If True (default, matching the original behaviour) print the total
        number of n-gram occurrences counted.

    Returns
    -------
    list of tuple
        The top-``k`` n-grams, each a tuple of ``n`` tokens, most frequent
        first. Ties keep first-encountered order.
    """
    counts = Counter()
    for headline, _category in dataset:
        # Feed each headline's n-grams straight into the counter instead of
        # first collecting every occurrence into one large list.
        counts.update(nltk.ngrams(nltk.word_tokenize(headline), n))
    if verbose:
        # Total occurrences (not distinct n-grams) — same number as before.
        print(sum(counts.values()))
    return [gram for gram, _count in counts.most_common(k)]
        

In [49]:
# Vocabulary size per n-gram order: 500 + 300 + 200 = 1000 count features per headline.
# NOTE: despite the "_dict" suffix these are lists of n-gram tuples (get_n_grams returns a list).
unigrams_dict, bigrams_dict, trigrams_dict = (
    get_n_grams(headlines_and_category, order, top_k)
    for order, top_k in ((1, 500), (2, 300), (3, 200))
)

187445




169606
151792


In [50]:
# POS tags used for the optional part-of-speech count features.
# sorted() instead of list(): iteration order of a set of strings varies between
# interpreter runs (hash randomisation), so list(set) would shuffle POS feature
# columns from run to run. Sorting fixes the column order.
# NOTE(review): this list may not cover every tag the tagger can emit
# (e.g. 'SYM', 'WP$') — confirm against the tagset in use.
pos_list = sorted({'CD', 'CC', 'RP', 'NNPS', 'IN', ',', '$', 'FW', 'RBR', 'JJ', "''", ')', 'VBD', 'VBP', 'POS', ':', 'NNS', '#', 'PRP', '(', 'VBN', 'PDT', 'JJS', 'VBG', 'PRP$', 'RBS', 'LS', '.', 'EX', 'NN', '``', 'DT', 'RB', 'WDT', 'VB', 'UH', 'TO', 'JJR', 'VBZ', 'MD', 'NNP', 'WP', 'WRB'})

In [70]:
def _first_index_map(vocab):
    """Map each vocabulary entry to the position of its first occurrence (same as ``list.index``)."""
    index = {}
    for position, gram in enumerate(vocab):
        index.setdefault(gram, position)
    return index


def _count_vector(items, index, size):
    """Count occurrences of vocabulary entries in ``items``; entries missing from ``index`` are skipped."""
    vector = [0] * size
    for item in items:
        position = index.get(item)
        if position is not None:
            vector[position] += 1
    return vector


def generate_features(dataset, unigrams_dict, bigrams_dict, trigrams_dict, pos_list, category_list, include_pos=False):
    """Turn headlines into n-gram count vectors and integer class labels.

    Parameters
    ----------
    dataset : iterable of (headline, category) pairs
    unigrams_dict, bigrams_dict, trigrams_dict : sequence of tuple
        Vocabularies (lists of n-gram tuples, e.g. from ``get_n_grams``);
        each defines one block of count columns, in that order.
    pos_list : sequence of str
        POS tag vocabulary. Only used when ``include_pos`` is True.
    category_list : list of str
        Label vocabulary; a headline's label is ``category_list.index(category)``.
    include_pos : bool, optional
        If True, append one count column per tag in ``pos_list`` (tags
        produced by ``nltk.pos_tag`` that are not in ``pos_list`` are
        ignored). Defaults to False, the notebook's original behaviour.

    Returns
    -------
    (X, y) : (list of list of int, list of int)
        One row per headline of length
        ``len(unigrams_dict) + len(bigrams_dict) + len(trigrams_dict)``
        (plus ``len(pos_list)`` when ``include_pos``), and one label each.

    Raises
    ------
    ValueError
        If a headline's category is not in ``category_list``.
    """
    vocabularies = [unigrams_dict, bigrams_dict, trigrams_dict]
    # Built once, outside the headline loop: dict lookups replace the
    # per-n-gram `in` / `.index` linear scans over the vocabulary lists.
    vocab_indexes = [_first_index_map(vocab) for vocab in vocabularies]
    pos_index = _first_index_map(pos_list) if include_pos else None
    X = []
    y = []
    for headline, category in dataset:
        text = nltk.word_tokenize(headline)
        features = []
        # order 1, 2, 3 pairs with the unigram, bigram, trigram vocabularies.
        for order, (vocab, index) in enumerate(zip(vocabularies, vocab_indexes), start=1):
            features.extend(_count_vector(nltk.ngrams(text, order), index, len(vocab)))
        if include_pos:
            tags = (tag for _word, tag in nltk.pos_tag(text))
            features.extend(_count_vector(tags, pos_index, len(pos_list)))
        X.append(features)
        y.append(category_list.index(category))
    return (X, y)

In [71]:
# Build the n-gram count matrix X and integer labels y (indices into category_list).
X, y = generate_features(
    headlines_and_category,
    unigrams_dict,
    bigrams_dict,
    trigrams_dict,
    pos_list,
    category_list,
)



In [72]:
# Width of one feature vector (number of count columns).
n_features = len(X[0])
print(n_features)

1000


In [73]:
# Label of the first headline as an index into category_list (3 -> 'CRIME').
first_label = y[0]
print(first_label)

3


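# Decision tree

Baseline classifier: a decision tree trained on the 1,000 n-gram count features, scored by 10-fold cross-validated accuracy.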


In [74]:
# Fixed seed: DecisionTreeClassifier randomly permutes the features at each split,
# so without random_state the fitted tree (and the cross-validation scores) can
# change from run to run.
RANDOM_STATE = 42
clf = tree.DecisionTreeClassifier(random_state=RANDOM_STATE)

In [75]:
# 10-fold cross-validated accuracy of the decision tree on the n-gram features.
scores = cross_val_score(clf, X, y, cv=10, scoring='accuracy')
print(scores)

[0.61387801 0.58174692 0.64389698 0.57422969 0.58912556 0.59473094
 0.58721256 0.53142536 0.53535354 0.50617284]


In [76]:
# Mean accuracy across the cross-validation folds.
mean_accuracy = sum(scores) / len(scores)
print(mean_accuracy)

0.5757772401644059
