In [1]:
import time
import re
import math
import string
import nltk
nltk.download('punkt')
from collections import Counter
from nltk.stem.porter import *
from nltk.corpus import stopwords
from sklearn.metrics import f1_score, precision_score, recall_score

# Parse the 22 Reuters-21578 SGML files into `data` / `id_list`.
# Only articles whose header has TOPICS="YES", that list at least one topic
# and that contain a <BODY> are kept.
data = {}   # article id -> {'type': LEWISSPLIT value, 'category': [topics], 'body': text}
new_id = -1   # id of the article currently being parsed (-1 means "none")
topic_flag = False   # True once the current article has a non-empty topic list
body_flag = False   # True while the parser is inside the current article's body
body = ''   # body text accumulated so far for the current article
id_list = []   # ids of all articles that were kept

begin = time.time()
filename = ["data/reut2-%03d.sgm" % r for r in range(0, 22)]
for fn in filename:
    print('preprocessing ' + fn + ' begin~~~')
    # Context manager so every file handle is closed as soon as it has been read.
    with open(fn, 'r', encoding = 'ISO-8859-1') as file:
        line = file.readlines()

    for l in line:
        if l.find('REUTERS TOPICS') != -1:
            # A new article header: throw away whatever is left of an unfinished
            # previous article (topics without a body, or a header with neither).
            # Resetting unconditionally keeps a stale id from being reused when
            # this header is not TOPICS="YES".
            if new_id != -1:
                data.pop(new_id, None)
            new_id = -1
            topic_flag = False
            body_flag = False
            body = ''

            dtype = re.split('[=]*["]+[ ]*', l)
            if dtype[1] == 'YES':   # value of the TOPICS attribute
                new_id = dtype[9]   # value of the NEWID attribute
                data[new_id] = {}
                data[new_id]['type'] = dtype[3]   # LEWISSPLIT => TRAIN or TEST

        elif l.find('<TOPICS><D>') != -1 and new_id != -1:
            topic_flag = True
            kind = re.split('<TOPICS><D>|</D><D>|</D></TOPICS>\n', l)   # split out each topic
            # the split leaves an empty string at both ends; drop them
            kind.pop()
            kind.pop(0)
            data[new_id]['category'] = kind   # topic category

        elif l.find('<BODY>') != -1 and new_id != -1:
            # drop articles that have no topic label
            if not topic_flag:
                del data[new_id]
                new_id = -1
                continue

            body_flag = True
            body = re.split('<BODY>', l)[1].strip()   # first line of the body
            continue

        elif l.find('</BODY>') != -1 and new_id != -1 and topic_flag:  # last line of the body
            data[new_id]['body'] = body
            id_list.append(new_id)    # store the new id
            body_flag = False
            topic_flag = False
            new_id = -1
            body = ''

        if body_flag and topic_flag:
            body += l.strip()   # append every body line (stripped) to the running text
end = time.time()
print('\nread data: ', end - begin, 's')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\MI\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


preprocessing data/reut2-000.sgm begin~~~
preprocessing data/reut2-001.sgm begin~~~
preprocessing data/reut2-002.sgm begin~~~
preprocessing data/reut2-003.sgm begin~~~
preprocessing data/reut2-004.sgm begin~~~
preprocessing data/reut2-005.sgm begin~~~
preprocessing data/reut2-006.sgm begin~~~
preprocessing data/reut2-007.sgm begin~~~
preprocessing data/reut2-008.sgm begin~~~
preprocessing data/reut2-009.sgm begin~~~
preprocessing data/reut2-010.sgm begin~~~
preprocessing data/reut2-011.sgm begin~~~
preprocessing data/reut2-012.sgm begin~~~
preprocessing data/reut2-013.sgm begin~~~
preprocessing data/reut2-014.sgm begin~~~
preprocessing data/reut2-015.sgm begin~~~
preprocessing data/reut2-016.sgm begin~~~
preprocessing data/reut2-017.sgm begin~~~
preprocessing data/reut2-018.sgm begin~~~
preprocessing data/reut2-019.sgm begin~~~
preprocessing data/reut2-020.sgm begin~~~
preprocessing data/reut2-021.sgm begin~~~

read data:  2.205906629562378 s


In [2]:
def stem(tokens, stemmer):
    """Return a new list holding ``stemmer.stem(token)`` for each token, in order."""
    return [stemmer.stem(token) for token in tokens]

# Translation table that deletes every ASCII punctuation character.
# Built once instead of on every call to tf().
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)

# English stopwords, loaded lazily on first use and then reused.
_STOPWORD_CACHE = None


def _english_stopwords():
    """Return the NLTK English stopwords as a frozenset, loading them only once."""
    global _STOPWORD_CACHE
    if _STOPWORD_CACHE is None:
        _STOPWORD_CACHE = frozenset(stopwords.words('english'))
    return _STOPWORD_CACHE


def tf(text):
    """Turn raw text into a dict of log-scaled term frequencies.

    Steps: lowercase, strip punctuation, tokenize with NLTK, drop English
    stopwords, Porter-stem, count, then weight each count as ``1 + log(count)``.

    Args:
        text: the document as a single string.

    Returns:
        dict mapping each stemmed term to ``1 + math.log(count)``; empty when
        no tokens survive the filtering.
    """
    no_punctuation = text.lower().translate(_PUNCTUATION_TABLE)

    tokens = nltk.word_tokenize(no_punctuation)   # string to tokens

    # Membership is tested against a cached set: calling stopwords.words()
    # inside the comprehension re-evaluated it for every single token.
    stop_words = _english_stopwords()
    filtered = [w for w in tokens if w not in stop_words]

    stemmer = PorterStemmer()
    stemmed = stem(filtered, stemmer)

    # 1 + log(tf)
    return {term: 1 + math.log(count) for term, count in Counter(stemmed).items()}

In [3]:
idf = {}   # term -> document frequency first, then idf (see below)
train_count = 0   # number of training documents

# Replace each article's raw body text with its log-scaled tf dict
# (both the train and the test split).
start = time.time()
for doc_id in id_list:
    data[doc_id]['body'] = tf(data[doc_id]['body'])
end = time.time()
print('train and test tf: ', end - start, 's')

# Document frequency, counted over the training split only.
start = time.time()
for doc_id in id_list:
    if data[doc_id]['type'] != 'TRAIN':
        continue
    train_count += 1
    for term in data[doc_id]['body']:
        idf[term] = idf.get(term, 0) + 1

# Convert each document frequency into idf = log(train_count / df).
# Only values are reassigned, so iterating the dict while updating is safe.
for term, doc_freq in idf.items():
    idf[term] = math.log(train_count / doc_freq)

end = time.time()
print('idf: ', end - start, 's')

train and test tf:  663.1301655769348 s
idf:  0.382641077041626 s


In [4]:
# tf-idf: scale every tf weight by the term's idf, in place.
start = time.time()
for doc_id in id_list:
    weights = data[doc_id]['body']
    for term in weights:
        # a term that never occurred in the training set has no idf entry,
        # so its weight becomes 0
        weights[term] = weights[term] * idf[term] if term in idf else 0
end = time.time()
print('train and test tfidf: ', end - start, 's')

train and test tfidf:  0.6519205570220947 s


In [5]:
# Sum two weight vectors (used to accumulate training-set tf-idf per topic).
def combine(a, b):
    """Add two term -> weight mappings key by key and return a plain dict.

    Both arguments are passed through ``Counter``; the sum uses Counter
    addition, so keys whose total is not positive are left out of the result.
    """
    left = Counter(a)
    right = Counter(b)
    total = left + right
    return dict(total)

In [6]:
bag = {}   # topic -> {'body': centroid tf-idf vector, 'count': number of training articles}

# Read the topic list; the context manager closes the file once it has been read.
with open('data/all-topics-strings.lc.txt', 'r') as file:
    line = file.readlines()

# bag initial
for l in line:
    bag[l.strip()] = {}
    bag[l.strip()]['body'] = {}   # summed (later averaged) tf-idf of every article with this topic
    bag[l.strip()]['count'] = 0   # how many training articles carry this topic

start = time.time()
for i in id_list:
    if data[i]['type'] == 'TRAIN':
        for t in data[i]['category']:
            # Accumulate into the topic's vector in place. Rebuilding the whole
            # vector through Counter addition for every article made this loop
            # quadratic in the vocabulary size of large topics.
            centroid = bag[t]['body']
            for word, weight in data[i]['body'].items():
                # zero weights add nothing to the sum, so they are skipped;
                # Counter addition likewise left out non-positive totals
                if weight > 0:
                    centroid[word] = centroid.get(word, 0) + weight
            bag[t]['count'] += 1

for key in bag.keys():
    for word in list(bag[key]['body']):
        bag[key]['body'][word] = bag[key]['body'][word] / bag[key]['count']   # average the summed tf-idf
end = time.time()
print('train the centroid vector: ', end - start, 's')

train the centroid vector:  100.9134087562561 s


In [7]:
def vsm(q, centroids=None):
    """Return the topic whose centroid has the highest cosine similarity to ``q``.

    Args:
        q: dict mapping term -> weight for the query document.
        centroids: optional dict mapping topic -> {'body': term -> weight, ...}.
            When omitted, the notebook-level ``bag`` is used.

    Returns:
        The key of ``centroids`` with the largest similarity (rounded to 12
        decimals). Ties go to the key that comes first in iteration order.
        A topic with an all-zero/empty centroid, or any topic when the query
        norm is zero, scores 0.

    Raises:
        ValueError: if ``centroids`` is empty (``max`` of an empty mapping).
    """
    if centroids is None:
        centroids = bag

    # The query norm does not depend on the topic, so compute it once
    # rather than once per topic.
    query_sq = 0
    for key in q.keys():
        query_sq += (q[key] * q[key])
    query_V = math.sqrt(query_sq)

    ans = {}
    for t in centroids.keys():
        centroid = centroids[t]['body']

        dot_product = 0
        for key in q.keys():
            if key in centroid:
                dot_product += (q[key] * centroid[key])

        # Iterate keys (not .values()) so an empty-string placeholder centroid
        # is handled the same way as an empty dict.
        doc_sq = 0
        for key in centroid:
            doc_sq += (centroid[key] * centroid[key])
        doc_V = math.sqrt(doc_sq)

        if doc_V == 0 or query_V == 0:
            ans[t] = 0
        else:
            ans[t] = (round(dot_product / (doc_V * query_V), 12))

    return max(ans, key=ans.get)

In [8]:
test_count = 0   # number of test documents
correct = 0   # number of correct predictions
test_actually = []   # reference label per test document
test_predict = []   # predicted label per test document

start = time.time()
for doc_id in id_list:
    doc = data[doc_id]
    if doc['type'] != 'TEST':
        continue

    topic = vsm(doc['body'])
    test_count += 1

    # A prediction counts as correct when it matches any of the document's
    # topics; in that case the predicted topic is also recorded as the
    # reference label. Otherwise the document's first topic is the reference.
    hit = topic in doc['category']
    if hit:
        correct += 1
    test_actually.append(topic if hit else doc['category'][0])
    test_predict.append(topic)
end = time.time()
print('predict: ', end - start, 's')

predict:  398.45625257492065 s


In [9]:
# Macro-averaged scores over all labels seen in the reference or predicted lists.
# zero_division=0 states explicitly that a label with no predicted (or no true)
# samples scores 0 -- the same value the default "warn" setting uses -- without
# emitting an undefined-metric warning for each call.
p = precision_score(test_actually, test_predict, average='macro', zero_division=0)
r = recall_score(test_actually, test_predict, average='macro', zero_division=0)
f1 = f1_score(test_actually, test_predict, average='macro', zero_division=0)

print('precision: ', p)
print('recall: ', r)
print('F-measure: ', f1)
# accuracy = share of test documents whose predicted topic is one of their topics
print('accuracy: ', (correct / test_count))

precision:  0.6519652779805617
recall:  0.6471236170237964
F-measure:  0.6294657662924553
accuracy:  0.8975200583515682


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
# Leftover, unexecuted cell: only prints an empty line.
print()