### import

In [1]:
# 標準ライブラリ
import os
import datetime as dt
from datetime import timedelta
from statistics import mean, median,variance,stdev
import itertools
# サードパーティライブラリ
import numpy as np
import pandas as pd
import burst_detection as bd

## 定数

In [2]:
# Thresholds for which per-flag count files exist; index i of each
# target_counts[pref][flag] list corresponds to FIXED_NUMS[i].
FIXED_NUMS = [50, 150, 250, 500]

# Root directory for inputs and outputs. Set the KOYO_RESULT_DIR environment
# variable to relocate it instead of editing this machine-specific path.
# os.path.join(..., "") guarantees a trailing separator, because paths below
# are built by plain string concatenation (RESULT_DIR + "sub/...").
RESULT_DIR = os.path.join(
    os.environ.get("KOYO_RESULT_DIR", "/Users/daigo/workspace/koyo/result/"), ""
)

# Ground-truth periods (closed intervals [start, end]) used as the "correct"
# labels, per prefecture code and category. Only "kaede" and "icho" are
# defined; the plotting/evaluation code treats "sonota" and "koyo" as the
# union of these two periods. The commented-out entries are unused and kept
# only for reference.
CORRECTS = {
    "tk": {
        "kaede": {"start": dt.date(2015, 12, 4), "end": dt.date(2015, 12, 12)},
        "icho": {"start": dt.date(2015, 11, 30), "end": dt.date(2015, 12, 11)},
        # "sonota": {"start": dt.date(2015, 11, 30), "end": dt.date(2015, 12, 12)},
        # "koyo": {"start": dt.date(2015, 11, 30), "end": dt.date(2015, 12, 12)}
    },
    "hk": {
        "kaede": {"start": dt.date(2015, 10, 29), "end": dt.date(2015, 11, 29)},
        "icho": {"start": dt.date(2015, 11, 2), "end": dt.date(2015, 11, 12)},
        # "sonota": {"start": dt.date(2015, 10, 29), "end": dt.date(2015, 11, 29)},
        # "koyo": {"start": dt.date(2015, 10, 29), "end": dt.date(2015, 11, 29)}
    },
    "is": {
        "kaede": {"start": dt.date(2015, 11, 22), "end": dt.date(2015, 12, 3)},
        "icho": {"start": dt.date(2015, 11, 4), "end": dt.date(2015, 11, 10)},
        # "sonota": {"start": dt.date(2015, 11, 4), "end": dt.date(2015, 12, 30)},
        # "koyo": {"start": dt.date(2015, 11, 4), "end": dt.date(2015, 12, 30)}
    }
}

### 連続した日付データ作成 date_range()
閉区間 [start, end] の日付を1日刻みで順に yield する（end が start より前なら何も出力しない）

In [3]:
def date_range(start, end):
    """Lazily yield every day from ``start`` to ``end`` inclusive.

    Steps one day at a time over the closed interval [start, end];
    nothing is yielded when ``end`` is earlier than ``start``.
    """
    one_day = timedelta(days=1)
    current = start
    while current <= end:
        yield current
        current += one_day

## 変数

In [4]:
# Prefecture codes (the keys of CORRECTS) and the category flags to evaluate.
prefs = ["tk", "hk", "is"]
flags = ["icho", "kaede", "sonota", "koyo"]
end = dt.date(2015, 12, 31)

# Daily axes, both closed intervals ending at `end`.
x_axis = list(date_range(dt.date(2015, 10, 1), end))
x_long_axis = list(date_range(dt.date(2015, 8, 15), end))
# Analysis period used by the rest of the notebook (intended to become the
# single switch for changing the period).
dates = x_long_axis

# Burst-detection parameters forwarded to bd.burst_detection as s / gamma.
param_s = 1.5
param_gamma = 1.0

## 関数宣言

### バースト検出

In [5]:
def burst_detection(target_counts, total_counts, s=None, gamma=None, smooth_win=1):
    """Return the optimal burst-state sequence for one target count series.

    Thin wrapper around ``bd.burst_detection`` from the ``burst_detection``
    package.

    Parameters
    ----------
    target_counts : sequence of numbers or numeric strings
        Number of target events at each time point (strings are accepted
        because they are converted with ``dtype=float``).
    total_counts : sequence of numbers or numeric strings
        Total number of events at each time point; must have the same length
        as ``target_counts``.
    s, gamma : float, optional
        Parameters forwarded to ``bd.burst_detection``. ``None`` (default)
        means the notebook-level ``param_s`` / ``param_gamma``, read at call
        time so changing those globals still takes effect.
    smooth_win : int, optional
        Smoothing window forwarded to ``bd.burst_detection`` (default 1).

    Returns
    -------
    q
        The state sequence returned by ``bd.burst_detection``, one entry per
        time point. The evaluation code in this notebook treats state 1 as a
        burst day and 0 as a non-burst day.

    Raises
    ------
    ValueError
        If the two count series differ in length.
    """
    if s is None:
        s = param_s
    if gamma is None:
        gamma = param_gamma

    target = np.array(target_counts, dtype=float)
    total = np.array(total_counts, dtype=float)
    if target.shape != total.shape:
        raise ValueError(
            f"target_counts ({len(target)}) and total_counts ({len(total)}) "
            "must have the same number of time points"
        )

    # Only the state sequence is used; the package also returns the
    # (possibly smoothed) d/r arrays and the state probabilities p, which
    # bd.enumerate_bursts / bd.burst_weights can consume when inspecting.
    q, _total, _target, _probs = bd.burst_detection(
        target, total, len(target), s=s, gamma=gamma, smooth_win=smooth_win
    )
    return q

### 見頃plot

In [6]:
from matplotlib import pyplot as plt
from matplotlib import dates as mdates
from matplotlib.dates import DateFormatter
import matplotlib.ticker as ticker
from matplotlib.dates import date2num

def plot_migoro(x, q, pref, flag, fixed_num):
    """Plot the ground-truth period against the detected burst states and save a PNG.

    Parameters
    ----------
    x : sequence of datetime.date
        Days on the x axis; ``q`` must be aligned with it.
    q : sequence
        Burst states per day, e.g. the output of ``burst_detection``.
    pref : str
        Prefecture key into ``CORRECTS``.
    flag : str
        "sonota" and "koyo" are drawn against the union of the kaede and
        icho periods; any other flag uses ``CORRECTS[pref][flag]``.
    fixed_num : int
        Threshold label shown on the figure and used in the file name.

    The figure is written to
    ``{RESULT_DIR}graph/constant/s{param_s}gamma{param_gamma}/{pref}/{flag}_{fixed_num:03}.png``
    (directories are created if needed) and always closed afterwards.
    """
    if flag in ("sonota", "koyo"):
        periods = [CORRECTS[pref]["kaede"], CORRECTS[pref]["icho"]]
    else:
        periods = [CORRECTS[pref][flag]]
    # 1 on days inside any correct period, 0 otherwise.
    y_correct = [
        1 if any(p["start"] <= day <= p["end"] for p in periods) else 0
        for day in x
    ]

    # Fixed tick positions: mid-month / start-of-month over the 2015 season.
    tick_dates = [
        dt.date(2015, 8, 15),
        dt.date(2015, 9, 1), dt.date(2015, 9, 15),
        dt.date(2015, 10, 1), dt.date(2015, 10, 15),
        dt.date(2015, 11, 1), dt.date(2015, 11, 15),
        dt.date(2015, 12, 1), dt.date(2015, 12, 15), dt.date(2015, 12, 31),
    ]

    fig, ax = plt.subplots()
    try:
        fig.text(0.2, 0.5, f"{pref}: {flag}", fontsize=20)
        fig.text(0.2, 0.4, f"fixed_num = {fixed_num}", fontsize=20)

        ax.xaxis.set_major_locator(ticker.FixedLocator(date2num(tick_dates)))
        ax.xaxis.set_major_formatter(DateFormatter('%m-%d'))
        ax.tick_params(axis='x', rotation=270)

        ax.plot(x, y_correct, label='correct')
        ax.plot(x, q, label='burst')
        ax.legend()

        fname = f"{flag}_{str(fixed_num).zfill(3)}.png"
        out_dir = f"{RESULT_DIR}graph/constant/s{param_s}gamma{param_gamma}/{pref}/"
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_dir + fname)
    finally:
        # Close this specific figure even if saving fails; the caller
        # creates one figure per (pref, flag, fixed_num) in a loop.
        plt.close(fig)

### F値算出<br>evaluate(result, pref, flag)

In [7]:
def evaluate(result, pref, flag, days=None):
    """Score a burst sequence against the ground-truth period and print the F-score.

    Parameters
    ----------
    result : pandas.DataFrame
        Column ``date`` holds the day; the second column (position 1) holds
        that day's burst state as a length-1 sequence (``value[0]`` is read).
        Every evaluated day must appear in ``result['date']``.
    pref : str
        Prefecture key into ``CORRECTS``.
    flag : str
        "icho" / "kaede" use that category's period; any other flag
        (e.g. "sonota", "koyo") uses the union of the icho and kaede periods.
    days : iterable of datetime.date, optional
        Days to evaluate. Defaults to the notebook-level ``dates``.

    Returns
    -------
    tuple
        ``(recall, precision, fscore)``; all three are 0 when there is no
        true positive. Days whose state is neither 0 nor 1 are reported and
        excluded from the counts.
    """
    if days is None:
        days = dates

    def period(name):
        # Closed interval [start, end] of the given category as a set of days.
        span = CORRECTS[pref][name]
        return set(date_range(span['start'], span['end']))

    if flag in ["icho", "kaede"]:
        correct_period = period(flag)
    else:
        correct_period = period('icho') | period('kaede')

    # tp: burst inside the correct period   fn: no burst inside it
    # fp: burst outside the correct period  tn: no burst outside it
    tp = tn = fn = fp = 0
    for oneday in days:
        state = result[result['date'] == oneday].iloc[0, 1][0]
        in_period = oneday in correct_period
        if state == 1:
            if in_period:
                tp += 1
            else:
                fp += 1
        elif state == 0:
            if in_period:
                fn += 1
            else:
                tn += 1
        else:
            print(f"error: unexpected burst state {state!r} on {oneday}")

    if tp == 0:
        recall = precision = fscore = 0
    else:
        # Recall over the correct-period days that were actually evaluated.
        recall = tp / (tp + fn)
        precision = tp / (tp + fp)
        fscore = (2 * recall * precision) / (recall + precision)

    print(f"F-score: {fscore}")

    return recall, precision, fscore

### F値plot

In [8]:
# def plot_fscore(fscore):


# main関数

### ファイル読み込み

In [11]:
rtweets_count_dir = RESULT_DIR + "rtweets_count_miss/"
total_count_dir = RESULT_DIR + "total_count/"

# target_counts[pref][flag][i] will hold the counts for FIXED_NUMS[i]; the
# slot list is sized from FIXED_NUMS so adding a threshold needs no edit here.
target_counts = {
    pref: {flag: [[] for _ in FIXED_NUMS] for flag in flags}
    for pref in prefs
}
# total_counts[pref] will hold the total daily counts for that prefecture.
total_counts = {}

target_counts['tk']

{'icho': [[], [], [], []],
 'kaede': [[], [], [], []],
 'sonota': [[], [], [], []],
 'koyo': [[], [], [], []]}

In [12]:
def _read_second_column(path):
    """Return the second tab-separated field of every non-blank line in ``path``.

    Values stay strings; ``burst_detection`` converts them to float.
    Both "\\n" and "\\r\\n" line endings are stripped.
    """
    with open(path, "r", encoding="utf-8") as count_file:
        return [
            line.rstrip("\r\n").split("\t")[1]
            for line in count_file
            if line.strip()
        ]


for pref, flag in itertools.product(prefs, flags):
    for i, fixed_num in enumerate(FIXED_NUMS):
        filename = f"{pref}_{flag}_{str(fixed_num).zfill(3)}rwords_count.txt"
        target_counts[pref][flag][i] = _read_second_column(rtweets_count_dir + filename)
        # Each series is plotted against `dates`, so lengths must match.
        assert len(target_counts[pref][flag][i]) == len(dates), (
            f"{filename}: {len(target_counts[pref][flag][i])} rows, expected {len(dates)}"
        )

for pref in prefs:
    filename = pref + "_total_dailycount.txt"
    total_counts[pref] = _read_second_column(total_count_dir + filename)
    assert len(total_counts[pref]) == len(dates), (
        f"{filename}: {len(total_counts[pref])} rows, expected {len(dates)}"
    )

FileNotFoundError: [Errno 2] No such file or directory: '/Users/daigo/workspace/koyo/result/related_words_miss/tk_icho_050rwords_count.txt'

### 処理部

In [None]:
# Empty results table with one column per reported field/metric.
results = pd.DataFrame(
    data=None,
    columns=["pref", "flag", "fixed_num", "recall", "precision", "fscore"],
)

In [None]:
# Best F-score and the fixed_num that achieved it, indexed [flag, pref].
max_score = pd.DataFrame(0.0, columns=prefs, index=flags)
best_fixed_num = pd.DataFrame(0, columns=prefs, index=flags)

# Collect one record per run and build the table once at the end:
# DataFrame.append was removed in pandas 2.0, growing a frame row by row is
# quadratic, and rebuilding makes this cell safe to re-run.
records = []
for pref, flag in itertools.product(prefs, flags):
    for i, fixed_num in enumerate(FIXED_NUMS):
        q = burst_detection(target_counts[pref][flag][i], total_counts[pref])
        plot_migoro(dates, q, pref, flag, fixed_num)

        # One row per day: column 0 is the date, column 1 the matching entry of q.
        q_dateframe = pd.DataFrame([dates, q]).T
        q_dateframe.columns = ['date', 'burst']
        recall, precision, fscore = evaluate(q_dateframe, pref, flag)

        # ">=" means a later (larger) fixed_num wins a tie.
        current_best = max_score.at[flag, pref]
        if fscore >= current_best:
            if fscore == current_best and fscore != 0:
                print("Maxfscoreが等しい条件があります。このMaxが上書きされなければ、書き直して")
            max_score.at[flag, pref] = fscore
            best_fixed_num.at[flag, pref] = fixed_num

        records.append([pref, flag, fixed_num, recall, precision, fscore])

results = pd.DataFrame(
    records,
    columns=["pref", "flag", "fixed_num", "recall", "precision", "fscore"],
)

print("---------\n終了\n---------\n")

In [None]:
# Save the full results table as TSV alongside the graphs for these
# parameters, then display it.
results.to_csv(
    RESULT_DIR + f"graph/constant/s{param_s}gamma{param_gamma}/result.tsv",
    sep="\t",
)
results

In [None]:
# For each prefecture, report every flag's best F-score and the fixed_num
# that produced it.
for pref in prefs:
    print(f"{pref} の最大F値")
    for flag in flags:
        score = max_score.at[flag, pref]
        chosen_num = best_fixed_num.at[flag, pref]
        print(f"{flag}: {score} ({chosen_num})")

### テスト

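**target_countsの構造**

```
target_counts{
    pref: {
        flag: [counts_050, counts_150, counts_250, counts_500]
        # 要素数は len(FIXED_NUMS)（現在は4）。i番目の要素は fixed_num = FIXED_NUMS[i] の
        # rwords_count ファイルの2列目を読み込んだカウント列（文字列のリスト）
    }
}
```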