In [1]:
# 標準ライブラリ
import os
import datetime as dt
from datetime import timedelta
from statistics import mean, median,variance,stdev
import itertools
# サードパーティライブラリ
import numpy as np
import burst_detection as bd

## 定数

In [2]:
# numofclasses = "4classes"
# Values iterated as `fixed_num`; they appear zero-padded in input/output file names.
FIXED_NUMS = [50, 150, 250]  # 500
# NOTE(review): absolute, machine-specific path — consider making it configurable.
RESULT_DIR = "/Users/daigo/workspace/koyo/result/"


def _span(start, end):
    """Build a {"start", "end"} record from two (month, day) pairs in 2015."""
    return {"start": dt.date(2015, *start), "end": dt.date(2015, *end)}


# Reference ("correct") periods, indexed as CORRECTS[pref][flag]; both ends are
# treated as inclusive by plot_migoro.
# Presumably tk/hk/is are Tokyo/Hokkaido/Ishikawa and the flags mean
# maple (kaede), ginkgo (icho), other (sonota) and autumn leaves (koyo) — TODO confirm.
CORRECTS = {
    "tk": {
        "kaede": _span((12, 4), (12, 12)),
        "icho": _span((11, 30), (12, 11)),
        "sonota": _span((11, 30), (12, 12)),
        "koyo": _span((11, 30), (12, 12)),
    },
    "hk": {
        "kaede": _span((10, 29), (11, 29)),
        "icho": _span((11, 2), (11, 12)),
        "sonota": _span((10, 29), (11, 29)),
        "koyo": _span((10, 29), (11, 29)),
    },
    "is": {
        "kaede": _span((11, 22), (12, 3)),
        "icho": _span((11, 4), (11, 10)),
        "sonota": _span((11, 4), (12, 30)),
        "koyo": _span((11, 4), (12, 30)),
    },
}

## 関数宣言

### 連続した日付データ作成
閉区間[start, end]の範囲で出力する

In [3]:
def date_range(start, end):
    """Yield every day of the closed interval [start, end] in ascending order.

    Args:
        start: first date to yield.
        end: last date to yield (inclusive).

    Yields:
        ``start`` advanced by 0, 1, ... whole days up to and including ``end``.
        Nothing is yielded when ``end`` lies before ``start``.
    """
    last_offset = (end - start).days
    offset = 0
    while offset <= last_offset:
        yield start + timedelta(days=offset)
        offset += 1

### バースト検出

In [4]:
def burst_detection(target_counts, total_counts, s=None, gamma=None):
    """Run ``bd.burst_detection`` on per-time-point counts and return its state sequence.

    Args:
        target_counts: number of target events at each time point. Converted
            with ``np.array(..., dtype=float)``, so numeric strings are accepted.
        total_counts: total number of events at each time point; converted
            the same way.
        s: value forwarded as ``s`` to ``bd.burst_detection``. When omitted,
            the notebook-level ``param_s`` is used (the original behaviour).
        gamma: value forwarded as ``gamma`` to ``bd.burst_detection``. When
            omitted, the notebook-level ``param_gamma`` is used.

    Returns:
        The first value returned by ``bd.burst_detection`` (``q``, the optimal
        state sequence according to the library's naming).
    """
    # Fall back to the notebook globals only when the caller did not pass
    # explicit values, so existing two-argument calls behave as before.
    if s is None:
        s = param_s
    if gamma is None:
        gamma = param_gamma

    # number of target events at each time point
    r = np.array(target_counts, dtype=float)
    # total number of events at each time point
    d = np.array(total_counts, dtype=float)
    # number of time points
    n = len(r)

    # Only the state sequence is needed here; the other returned values
    # (presumably the processed d, r and the state probabilities) are ignored.
    q, _, _, _ = bd.burst_detection(r, d, n, s=s, gamma=gamma, smooth_win=1)

    # bd.enumerate_bursts(q, 'burstLabel') / bd.burst_weights(...) can be called
    # on the result if a table of bursts or their weights is needed.
    return q

### 評価

In [5]:
def evaluate(dates, correct_dates, burst_result):
    """Placeholder for the evaluation step; not implemented yet.

    The arguments are currently ignored and ``None`` is returned.
    """
    return None

### 見頃plot

In [6]:
from matplotlib import pyplot as plt
from matplotlib import dates as mdates
from matplotlib.dates import DateFormatter
import matplotlib.ticker as ticker
from matplotlib.dates import date2num

def plot_migoro(x, y, pref, flag, fixed_num):
    """Plot the reference period against a burst series and save the figure as PNG.

    Args:
        x: sequence of dates for the x axis; each is compared with the
            ``start``/``end`` dates in ``CORRECTS[pref][flag]``.
        y: burst series to draw against ``x`` (labelled ``burst``).
        pref: first-level key into ``CORRECTS``; also used in the output path.
        flag: second-level key into ``CORRECTS``; also used in the file name.
        fixed_num: value shown in the figure and zero-padded into the file name.

    Side effects:
        Writes ``{RESULT_DIR}graph/{pref}/s{param_s}gamma{param_gamma}/
        {flag}_{fixed_num:03}.png`` (creating the directory if needed) and
        closes the figure. Reads the notebook globals ``param_s`` and
        ``param_gamma`` for the directory name.
    """
    period = CORRECTS[pref][flag]
    # 1 inside the closed interval [start, end], 0 everywhere else.
    y_correct = [
        1 if period['start'] <= date <= period['end'] else 0
        for date in x
    ]

    figure_ = plt.figure()  # create the Figure
    figure_.text(0.2, 0.5, str(pref) + ": " + str(flag), fontsize=20)
    figure_.text(0.2, 0.4, "fixed_num = " + str(fixed_num), fontsize=20)

    axes_ = figure_.add_subplot(111)  # create the Axes
    xaxis_ = axes_.xaxis  # grab the XAxis

    # Fixed tick positions: roughly the 1st and 15th of each month.
    x_numdate = date2num([
        dt.date(2015, 8, 15),
        dt.date(2015, 9, 1), dt.date(2015, 9, 15),
        dt.date(2015, 10, 1), dt.date(2015, 10, 15),
        dt.date(2015, 11, 1), dt.date(2015, 11, 15),
        dt.date(2015, 12, 1), dt.date(2015, 12, 15), dt.date(2015, 12, 31)
    ])

    xaxis_.set_major_locator(ticker.FixedLocator(x_numdate))
    axes_.tick_params(axis='x', rotation=270)
    xaxis_.set_major_formatter(DateFormatter('%m-%d'))

    axes_.plot(x, y_correct, label='correct')
    # Plot the series passed in as `y`; the previous code read a module-level
    # `q` here and therefore ignored the argument.
    axes_.plot(x, y, label='burst')
    axes_.legend()

    fname = f"{flag}_{str(fixed_num).zfill(3)}.png"
    out_dir = f"{RESULT_DIR}graph/{pref}/s{param_s}gamma{param_gamma}/"
    os.makedirs(out_dir, exist_ok=True)
    figure_.savefig(out_dir + fname)
    # Close this specific figure so figures do not accumulate across the loop.
    plt.close(figure_)

### F値plot

In [7]:
def plot_fscore(fscores, rates):
    """Placeholder for the F-score plot; not implemented yet.

    Both arguments are currently ignored and ``None`` is returned.
    """
    return None

# main関数

### 変数宣言

In [8]:
# First- and second-level keys used for CORRECTS and for the count file names.
prefs = ['tk', 'hk', 'is']
flags = ["icho", "kaede", "sonota", "koyo"]

# Read as globals by burst_detection (forwarded as s / gamma) and by
# plot_migoro (embedded in the output directory name).
param_s = 1.3
param_gamma = 1.0

fscores = []

# Both axes end on the same day and include it (date_range is a closed interval).
end = dt.date(2015, 12, 31)
x_axis = list(date_range(dt.date(2015, 10, 1), end))
x_long_axis = list(date_range(dt.date(2015, 8, 15), end))

### ファイル読み込み

In [9]:
# Input directories below RESULT_DIR.
rtweets_count_dir = f"{RESULT_DIR}rtweets_count/"
total_count_dir = f"{RESULT_DIR}total_count/"

# target_counts[pref][flag] starts as an empty list that the loading cell fills;
# total_counts is keyed by pref only.
target_counts = {}
total_counts = {}

for pref in prefs:
    target_counts[pref] = {}
    for flag in flags:
        target_counts[pref][flag] = []

**target_countsの構造**

```
target_counts{
    pref: {
        flag: [] # 要素数10でrate10~100のrtweet_countが入っている。
    }
}
```

ややこしい構造なので、クラスを作ってget_count(flag, rate)メソッドで持ってくるようにしたい。

```
class TweetCount: 
    def __init__(self, flag, rate, count_list):
        if flag == icho:
            self.icho = count_list
        elif flag == kaede:
            self.kaede = count_list
        elif flag == sonota:
            self.sonota = count_list
        elif flag == koyo:
            self.koyo = count_list
        else:
            print("error: invalid argument for the class of TweetCount ")
            exit(1)
    def get_count(self, flag, rate):
        return self.flag.rate
```
rateまで指定しなくて良い気がする。10要素のリストが返ってくれば良いかも。

In [10]:
def _read_count_column(path):
    """Return the second tab-separated field of every line of ``path``, as strings."""
    with open(path, "r") as count_file:
        return [row.rstrip('\n').split('\t')[1] for row in count_file]


# One list of values is appended per FIXED_NUMS entry, in FIXED_NUMS order, so
# target_counts[pref][flag][i] corresponds to FIXED_NUMS[i].
for pref, flag in itertools.product(prefs, flags):
    for fixed_num in FIXED_NUMS:
        filename = f"{pref}_{flag}_{str(fixed_num).zfill(3)}rwords_count.txt"
        target_counts[pref][flag].append(
            _read_count_column(rtweets_count_dir + filename)
        )

# One total-count series per pref.
for pref in prefs:
    filename = pref + "_total_dailycount.txt"
    total_counts[pref] = _read_count_column(total_count_dir + filename)

### 処理部

In [11]:
# For every (pref, flag, fixed_num) combination: detect bursts on the loaded
# counts and save the comparison plot over the long date axis.
for pref in prefs:
    for flag in flags:
        for i, fixed_num in enumerate(FIXED_NUMS):
            # target_counts[pref][flag][i] is the series loaded for FIXED_NUMS[i].
            q = burst_detection(target_counts[pref][flag][i], total_counts[pref])
            plot_migoro(x_long_axis, q, pref, flag, fixed_num)

### テスト

In [12]:
from pprint import pprint
# Fixed UTC+9 offset used for all datetimes below.
JST = dt.timezone(dt.timedelta(hours=9))

tokyo = dict(
    start=dt.datetime(2014, 11, 25, tzinfo=JST),
    end=dt.datetime(2014, 12, 18, tzinfo=JST),
)
hokkaido = dict(
    start=dt.datetime(2014, 10, 30, tzinfo=JST),
    end=dt.datetime(2014, 11, 3, tzinfo=JST),
)
ishikawa = dict(
    start=dt.datetime(2014, 11, 13, tzinfo=JST),
    end=dt.datetime(2014, 11, 29, tzinfo=JST),
)

# NOTE(review): this rebinds `prefs`, which the earlier cells define as the list
# ['tk', 'hk', 'is']; re-running those cells after this one will fail unless
# `prefs` is re-initialised first.
prefs = [tokyo, hokkaido, ishikawa]

for pref in prefs:
    # Half-open range [start, end) on the ISO-formatted timestamps
    # (looks like a MongoDB-style filter — confirm against the consumer).
    duration = {
        "created_at_iso": {
            '$gte': pref["start"].isoformat(),
            '$lt': pref["end"].isoformat(),
        },
    }
    season = {"icho": 1, **duration}

    pprint(season)

{'created_at_iso': {'$gte': '2014-11-25T00:00:00+09:00',
                    '$lt': '2014-12-18T00:00:00+09:00'},
 'icho': 1}
{'created_at_iso': {'$gte': '2014-10-30T00:00:00+09:00',
                    '$lt': '2014-11-03T00:00:00+09:00'},
 'icho': 1}
{'created_at_iso': {'$gte': '2014-11-13T00:00:00+09:00',
                    '$lt': '2014-11-29T00:00:00+09:00'},
 'icho': 1}
