In [1]:
# 必要ライブラリ
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import dates as mdates
from statistics import mean, median,variance,stdev
import datetime as dt
from datetime import timedelta
import burst_detection as bd
import numpy as np

In [2]:
# Convert a parsed [month, day, ...] record into a datetime.date.
def text_date_to_date(l, year=2015):
    """Build a ``datetime.date`` from the month and day stored in a record.

    Parameters
    ----------
    l : sequence
        Record whose first two items are the month (``l[0]``) and the day
        (``l[1]``) as integers; any further items are ignored.
    year : int, optional
        Year to attach (default 2015, the year previously hard-coded).

    Returns
    -------
    datetime.date

    Raises
    ------
    ValueError
        If month/day do not form a valid calendar date.
    """
    # Constructing the date directly avoids the zero-pad -> string -> strptime round-trip.
    return dt.date(year, l[0], l[1])

In [3]:
# Parse tab-separated "month<TAB>day<TAB>count" text lines into integer records.
def convert_list(lines):
    """Parse tab-separated text lines into ``[month, day, count]`` integer lists.

    Parameters
    ----------
    lines : iterable of str
        Lines as returned by ``readlines()``; only the first three
        tab-separated fields of each line are read.

    Returns
    -------
    list of list of int
        One ``[int, int, int]`` record per non-blank line, in input order.

    Raises
    ------
    ValueError
        If a non-blank line has a field that is not an integer.
    IndexError
        If a non-blank line has fewer than three fields.
    """
    records = []
    for line in lines:
        # Tolerate blank lines (e.g. a trailing newline at end of file),
        # which previously crashed on int('').
        if not line.strip():
            continue
        fields = line.split('\t')
        records.append([int(fields[0]), int(fields[1]), int(fields[2])])
    return records

In [4]:
def date_hit_list(list_d):
    """Split ``[month, day, count]`` records into parallel date and count lists.

    Returns ``[dates, hits]``: ``dates`` holds ``text_date_to_date(record)``
    and ``hits`` holds ``record[2]`` for each record, in input order.
    """
    # Single pass over the records, pairing each date with its count.
    pairs = [(text_date_to_date(record), record[2]) for record in list_d]
    dates = [pair[0] for pair in pairs]
    hits = [pair[1] for pair in pairs]
    return [dates, hits]

In [5]:
def daterange(start, end):
    """Yield each day from ``start`` up to but excluding ``end`` (half-open range)."""
    step = 0
    span = (end - start).days
    while step < span:
        yield start + timedelta(step)
        step += 1

## Kleinbergのバースト検出 (Kleinberg burst detection)

In [6]:
# Plot detected burst states against the ground-truth period and save them as PNG.
def plot_burst(q, dates, pid, rate, correct=None, window=slice(20, 90)):
    """Draw the burst state sequence and the ground-truth period, then save a PNG.

    Parameters
    ----------
    q : numpy.ndarray
        State sequence from ``bd.burst_detection``; ``q.T[0]`` is plotted
        (presumably a column of 0/1 states aligned with ``dates`` - verify).
    dates : list of datetime.date
        Dates aligned index-by-index with ``q``.
    pid : str
        Prefix of the saved file name.
    rate : int
        Shown on the figure and used in the file name.
    correct : sequence of two dates, optional
        Ground-truth period ``[start, end)``. Defaults to the notebook globals
        ``CORRECT_START`` / ``CORRECT_END`` (backward-compatible behavior).
    window : slice, optional
        Index range of ``dates`` to draw; defaults to the previously
        hard-coded ``slice(20, 90)``.

    Side effects: writes ``<pid>_<rate>_burst.png`` and shows the figure.
    """
    if correct is None:
        # Fallback to the globals set by the data-preparation cell.
        correct = (CORRECT_START, CORRECT_END)
    correct_start, correct_end = correct[0], correct[1]

    # Same half-open [start, end) convention that eva() scores with
    # (dates[index(start):index(end)]); the old strict `>` hid the start day.
    in_correct = [1 if correct_start <= date < correct_end else 0 for date in dates]

    x_date = np.array(dates[window])
    y_burst = q.T[0][window]
    y_correct = np.array(in_correct[window])

    fig, ax = plt.subplots(figsize=(10, 4), dpi=150)
    ax.set_yticks([0, 1])
    ax.xaxis.set_major_locator(mdates.DayLocator(bymonthday=None, interval=7, tz=None))
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%m/%d"))

    ax.plot(x_date, y_burst, linewidth=3.0, label='burst')
    ax.plot(x_date, y_correct, linewidth=3.0, label='correct')
    ax.set_ylabel('burst')

    fig.tight_layout()
    ax.tick_params(axis='both', labelsize=18)
    fig.autofmt_xdate()
    fig.text(0.12, 0.4, "rate = " + str(rate), fontsize=30)
    ax.legend()

    pngname = pid + '_' + str(rate) + '_burst.png'
    fig.savefig(pngname)

    plt.show()
    # This runs once per rate inside a loop; release the figure explicitly.
    plt.close(fig)

In [7]:
# Kleinberg burst detection on target vs. total daily counts.
def burst_detection(list_d, list_d_all, pid, rate, s=2, gamma=1, smooth_win=1):
    """Run ``bd.burst_detection`` on target/total counts, plot the result, return weighted bursts.

    Parameters
    ----------
    list_d : list of [month, day, count]
        Target-event records (as produced by ``convert_list``).
    list_d_all : list of [month, day, count]
        Total-event records for the same time points.
    pid : str
        Prefix for the saved burst plot.
    rate : int
        Label passed through to the plot.
    s, gamma, smooth_win : optional
        Forwarded to ``bd.burst_detection``; defaults are the values
        previously hard-coded (2, 1, 1).

    Returns
    -------
    Whatever ``bd.burst_weights`` returns (indexed positionally by ``eva``).

    Raises
    ------
    ValueError
        If the target and total series have different lengths.
    """
    target = date_hit_list(list_d)
    dates, hits = target[0], target[1]
    hits_all = date_hit_list(list_d_all)[1]

    # r and d are per-time-point counts; they must describe the same time points.
    if len(hits) != len(hits_all):
        raise ValueError(
            "burst_detection: target series has %d points but total series has %d"
            % (len(hits), len(hits_all))
        )

    r = np.array(hits, dtype=float)      # target events per time point
    d = np.array(hits_all, dtype=float)  # total events per time point
    n = len(r)                           # number of time points

    q, d, r, p = bd.burst_detection(r, d, n, s=s, gamma=gamma, smooth_win=smooth_win)
    bursts = bd.enumerate_bursts(q, 'burstLabel')
    weighted_bursts = bd.burst_weights(bursts, r, d, p)

    print('optimal state sequence: ')
    print(str(q.T))
    print('baseline probability: ' + str(p[0]))
    print('bursty probability: ' + str(p[1]))

    plot_burst(q, dates, pid, rate)

    return weighted_bursts

In [8]:
# Plot F-score versus rate, with the first rate's F-score as a reference line.
def plot_eva(fscores, rates, pid):
    """Plot F-score per rate and save ``<pid>_relation_fscore.pdf``.

    Parameters
    ----------
    fscores : list of float
        One F-score per entry of ``rates``, in the same order.
    rates : sequence of int
        x-axis values.
    pid : str
        Prefix of the saved PDF.

    Raises
    ------
    ValueError
        If ``fscores`` is empty or its length differs from ``rates``
        (e.g. the evaluation loop aborted part-way).
    """
    x = np.array(rates)
    y = np.array(fscores, dtype=float)
    if len(y) == 0 or len(y) != len(x):
        raise ValueError(
            "plot_eva: need one F-score per rate (got %d scores for %d rates)"
            % (len(y), len(x))
        )

    fig, ax = plt.subplots(figsize=(10, 10), dpi=150)
    ax.plot(x, y, label='Proposed method', linewidth=3.0)

    # Horizontal reference at the F-score of the first rate.
    baseline = np.full_like(y, y[0])
    ax.plot(x, baseline, label='Target', linewidth=3.0)

    # Set font sizes per element instead of mutating the global rcParams,
    # which leaked into every later figure when cells were re-run.
    ax.set_xlabel('percent', fontsize=18)
    ax.set_ylabel('F-score', fontsize=18)
    ax.tick_params(axis='both', labelsize=18)
    ax.legend(fontsize=18)
    fig.tight_layout()

    fig.savefig(pid + "_relation_fscore.pdf")

    plt.show()
    plt.close(fig)

In [9]:
# Evaluate detected burst periods against the ground-truth period.
def eva(dates, correct, weighted_bursts):
    """Score detected bursts against the ground-truth period (recall, precision, F-score).

    Parameters
    ----------
    dates : list of datetime.date
        Dates in time-point order; burst begin/end values index into this list.
    correct : sequence of two dates
        ``[start, end]`` of the ground truth; both must be present in ``dates``.
        The period is ``dates[index(start):index(end)]`` (end excluded).
    weighted_bursts : pandas.DataFrame-like
        Table with ``len()`` and ``.iloc``; columns 1 and 2 are read
        positionally as a burst's begin and (exclusive) end index.

    Returns
    -------
    float
        F-score; 0 when recall or precision is 0. Also prints rec, pre and F-score.

    Raises
    ------
    ValueError
        If a date in ``correct`` is not found in ``dates``.
    """
    correct_days = set(dates[dates.index(correct[0]):dates.index(correct[1])])

    burst_days = set()
    for row in range(len(weighted_bursts)):
        begin = weighted_bursts.iloc[row, 1]
        end = weighted_bursts.iloc[row, 2]
        burst_days.update(dates[begin:end])

    # Count over the full union of both periods. The previous
    # daterange(min, max) loop silently dropped the union's last date and
    # raised on an empty union.
    tp = len(correct_days & burst_days)
    fn = len(correct_days - burst_days)
    fp = len(burst_days - correct_days)

    rec = tp / (tp + fn) if (tp + fn) > 0 else 0
    pre = tp / (tp + fp) if (tp + fp) > 0 else 0

    if rec > 0 and pre > 0:
        fscore = (2 * rec * pre) / (rec + pre)
    else:
        fscore = 0

    print('rec', rec)
    print('pre', pre)
    print('F-score', fscore)

    return fscore

In [12]:
# Data preparation and evaluation loop: detect bursts and score them for each rate.

# --- configuration ---
pid = 'tk'
exts = 'soa'
rates = range(10, 101, 10)
# NOTE(review): absolute, user-specific path; move to a configurable data dir.
KOYO_COUNT_DIR = "/Users/daigo/workspace/koyo/2015/count"
all_tweets = "sample/count/" + pid + "_all.txt"

# Ground-truth peak period, scored by eva() as half-open [start, end).
# plot_burst() reads these two module-level globals, so keep the names.
# NOTE(review): the original comment described a cherry-blossom (bloom to
# full bloom) window, but this run uses the koyo data - confirm the label.
CORRECT_START = text_date_to_date([11, 30])
CORRECT_END = text_date_to_date([12, 22])
correct = [CORRECT_START, CORRECT_END]

# The total-count file does not depend on the rate: read and parse it once.
with open(all_tweets, "r") as fp_h_all:
    line_all = fp_h_all.readlines()
list_d_all = convert_list(line_all)

fscores = []

for rate in rates:
    koyo_tweets = (
        KOYO_COUNT_DIR + "/" + pid + "/" + pid + "_koyo_" + exts + str(rate) + "count.tsv"
    )
    with open(koyo_tweets, "r") as fp_h:
        line = fp_h.readlines()
    list_d = convert_list(line)

    # Dates (x axis) are needed here for scoring; the counts are handled
    # inside burst_detection().
    dates = date_hit_list(list_d)[0]

    burst_result = burst_detection(list_d, list_d_all, pid, rate)
    fscore = eva(dates, correct, burst_result)

    fscores.append(fscore)


FileNotFoundError: [Errno 2] No such file or directory: '/Users/daigo/workspace/koyo/2015/count/tk/tk_koyo_soa0count.tsv'

In [None]:
# Summarize F-score per rate (needs `fscores` filled by the evaluation loop above).
plot_eva(fscores=fscores, rates=rates, pid=pid)

In [None]:
# Inspect the date range covered by a burst label (slice `dates` with its begin/end indices), e.g.
# dates[65:84]