In [None]:
import datetime
import os
import sys
from collections import namedtuple
from string import Template

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import pandas as pd

# Make the project's `src` package importable from this notebook's directory.
sys.path.insert(0, '..')
from src.data import TimeSeries
from src.methods.spc import FControlChart, PatternFunction
from src.eval import (
    mean_time_from_event,
    classification_metrics
)

In [None]:
# Load the train / validation / test splits of the blood-refrigerator data
# with the pandas backend, then report their shapes and preview the dev split.
br_train, br_dev, br_test = [
    TimeSeries.from_csv('pandas', csv_path)
    for csv_path in (
        '../data/blood-refrigerator/train.csv',
        '../data/blood-refrigerator/val.csv',
        '../data/blood-refrigerator/test.csv',
    )
]
print(br_train.shape, br_dev.shape, br_test.shape)
br_dev.head()

In [None]:
# Parse the `timestamp` column of every split first, then break each split
# into per-day frames (stored in `.time_series`).
_all_splits = (br_train, br_dev, br_test)

for _ts in _all_splits:
    _ts.parse_datetime('timestamp')

for _ts in _all_splits:
    _ts.split_by_day()

# Number of days in train / dev / test.
tuple(len(_ts.time_series) for _ts in _all_splits)

In [None]:
# Build nested training sets from the most recent 25/50/75/100 % of training
# days. Each smaller split is a suffix of every larger one, so a column that
# varies inside the smallest split also varies inside every larger one.
splits = [0.25, 0.5, 0.75, 1.0]
TARGET_COL = 'PW_0.5h'
NON_FEATURE_COLS = ['timestamp', TARGET_COL, 'date', 'time']

all_train_days = list(br_train.time_series.keys())
n_days = len(all_train_days)
br_train_splits = {}
keep_cols = None

for pct in sorted(splits):
    # Keep at least one day: int(pct * n_days) can round down to 0, and
    # all_train_days[-0:] would silently select *every* day.
    n_split_days = max(1, int(pct * n_days))
    train_days = all_train_days[-n_split_days:]

    y = pd.concat([br_train.time_series[k][TARGET_COL] for k in train_days])
    X = pd.concat([
        br_train.time_series[k].drop(columns=NON_FEATURE_COLS)
        for k in train_days
    ])

    if keep_cols is None:
        # Choose the feature set once, on the smallest split: drop constant
        # (std == 0) columns.
        keep_cols = pd.Index([c for c in X.columns if np.std(X[c]) != 0])

    # Use the same columns for every split so each chart, the recorded
    # 'cols', and the dev/test matrices (built from keep_cols) all agree.
    X = X[keep_cols]
    assert X.shape[0] == y.shape[0], 'features and target must align row-wise'

    br_train_splits[str(pct)] = {
        'X': X.values,
        'y': y.values,
        'cols': keep_cols.tolist(),
    }
    print(f"{pct}\t-\t{X.shape}\t-\t{y.shape}\t-\n{keep_cols.tolist()}\n")

In [None]:
def _day_arrays(ts, feature_cols):
    """Per-day numpy arrays for one split.

    Returns {'X': {day: features restricted to `feature_cols`},
             'y': {day: 'PW_0.5h' target values}}.
    """
    per_day = ts.time_series
    return {
        'X': {day: frame[feature_cols].values for day, frame in per_day.items()},
        'y': {day: frame['PW_0.5h'].values for day, frame in per_day.items()},
    }


br_dev_data = _day_arrays(br_dev, keep_cols)
br_test_data = _day_arrays(br_test, keep_cols)

In [None]:
# Fit one F control chart per training split and report its limits.
charts = {}

for split_name, split_data in br_train_splits.items():
    print(split_name)
    br_chart = FControlChart()
    br_chart.determine_parameters(split_data['X'])
    charts[split_name] = br_chart
    print(br_chart.lcl, br_chart.center_line, br_chart.ucl)

In [None]:
def exceeds_n_breaches(values: np.ndarray, ucl, n):
    """Return True if at least `n` entries of `values` are above `ucl`.

    Parameters
    ----------
    values : np.ndarray
        Chart statistics for one window; every entry is counted.
    ucl : float
        Upper control limit.
    n : int
        Minimum number of breaches that raises the alarm.
    """
    # Compare against `n` (not a fixed count) so each window size uses its own threshold.
    return bool((np.asarray(values) > ucl).sum() >= n)


def n_sequential_breaches(values: np.ndarray, ucl, n):
    """Return True if `values` contains at least `n` consecutive entries above `ucl`.

    Parameters
    ----------
    values : np.ndarray
        Chart statistics for one window. It is flattened; presumably it is 1-D
        and ordered in time (TODO confirm against PatternFunction).
    ucl : float
        Upper control limit.
    n : int
        Required length of the run of breaches (positive).
    """
    run = 0
    for breached in (np.asarray(values) > ucl).ravel():
        # Extend the current run on a breach, reset it otherwise.
        run = run + 1 if breached else 0
        if run >= n:
            return True
    return False

def _collect_matches(chart, X_by_date):
    """Run every pattern registered on `chart` over each day's feature matrix.

    Returns {pattern_name: {date: result}}, where `result` is whatever
    `chart.check_patterns` reports for that pattern on that day.
    """
    matches = {}
    for dt, X_day in X_by_date.items():
        for pattern, res in chart.check_patterns(X_day).items():
            matches.setdefault(pattern, {})[dt] = res
    return matches


dev_matches = {}
test_matches = {}

for nm, chart in charts.items():
    print(nm)

    for n in [5, 10, 20, 40, 80]:
        chart.add_patterns(
            {
                # At least n breaches of the UCL within a window of size 2n.
                f'{n}per{n*2}at0.05': PatternFunction(
                    exceeds_n_breaches,
                    int(n*2),
                    {'ucl': chart.ucl, 'n': n}
                ),
                # n consecutive breaches of the UCL (window of size n).
                f'{n}seqAt0.05': PatternFunction(
                    n_sequential_breaches,
                    int(n),
                    {'ucl': chart.ucl, 'n': n}
                ),
            }
        )

    dev_matches[nm] = _collect_matches(chart, br_dev_data['X'])
    test_matches[nm] = _collect_matches(chart, br_test_data['X'])

    # Number of patterns evaluated for this chart on dev / test.
    print(len(dev_matches[nm]), len(test_matches[nm]))

In [None]:
# Text template for printing one pattern's evaluation metrics.
res_out = Template(
    "Pattern: $pattern\n\tMTFE:\t$mtfe\n\tF1:\t$f1\n\tRecall:\t$recall\n\tPrecision:\t$precision\n"
)

# One evaluation row: which split/pattern, whether strict matching was used
# (0/1), and the resulting metrics.
ResultTup = namedtuple(
    'ResultTup',
    'split_pct pattern strict mean_time_from_event f1 precision recall'
)

In [None]:
# Score every (split, pattern) detector on the dev days: first non-strict,
# then strict. `strict` is stored as 0/1 so the histograms below can filter on it.
dev_result_for_out = []

for strict_flag, header in ((0, '=== Not Strict ==='), (1, '\n=== Strict ===')):
    print(header)
    # Forward `strict` only when it is on, mirroring the explicit calls.
    extra_kwargs = {'strict': True} if strict_flag else {}
    for pct, res in dev_matches.items():
        for pattern, matches in res.items():
            diffs, mtfe = mean_time_from_event(br_dev_data['y'], matches, **extra_kwargs)
            hits, mets = classification_metrics(br_dev_data['y'], matches, **extra_kwargs)
            dev_result_for_out.append(
                ResultTup(
                    pct, pattern, strict_flag, mtfe,
                    mets['f1'], mets['precision'], mets['recall'],
                )
            )

dev_result_df = pd.DataFrame(dev_result_for_out)
dev_result_df

In [None]:
dev_result_df.loc[dev_result_df['strict'] == 0, 'mean_time_from_event'].hist()

In [None]:
dev_result_df.loc[dev_result_df['strict'] == 1, 'mean_time_from_event'].hist()

In [None]:
# Score every (split, pattern) detector on the test days: first non-strict,
# then strict. Both passes append to the same list, so the strict-mode
# histogram below actually has data. `strict` is stored as 0/1.
test_result_for_out = []

for strict_flag, header in ((0, '=== Not Strict ==='), (1, '\n=== Strict ===')):
    print(header)
    # Forward `strict` only when it is on, mirroring the non-strict call form.
    extra_kwargs = {'strict': True} if strict_flag else {}
    for pct, res in test_matches.items():
        for pattern, matches in res.items():
            diffs, mtfe = mean_time_from_event(br_test_data['y'], matches, **extra_kwargs)
            hits, mets = classification_metrics(br_test_data['y'], matches, **extra_kwargs)
            test_result_for_out.append(
                ResultTup(
                    pct, pattern, strict_flag, mtfe,
                    mets['f1'], mets['precision'], mets['recall'],
                )
            )

test_result_df = pd.DataFrame(test_result_for_out)
test_result_df

In [None]:
test_result_df.loc[test_result_df['strict'] == 0, 'mean_time_from_event'].hist()

In [None]:
test_result_df.loc[test_result_df['strict'] == 1, 'mean_time_from_event'].hist()

In [None]:
# Control limits of the chart fitted on the smallest (0.25) training split.
_ref_chart = charts['0.25']
lcl, cl, ucl = _ref_chart.lcl, _ref_chart.center_line, _ref_chart.ucl
lcl, cl, ucl

In [None]:
# Chart statistic of the 0.25-split chart on a single test day (plotted below).
_q_day = datetime.date(2022, 12, 22)
Q = charts['0.25'](br_test_data['X'][_q_day])

In [None]:
# Labels for the same test day used to compute `Q`, so each plotted statistic
# is coloured by its own label rather than another day's.
y = br_test_data['y'][datetime.date(2022, 12, 22)]

In [None]:
# Diagnostic plot: log chart statistic for the selected test day, with points
# whose label is non-zero highlighted, against the 0.25-split chart's limits.
# Assumes Q is 1-D with one value per row of that day's data — TODO confirm.
log_q = np.log(np.asarray(Q, dtype=float))
labels = np.asarray(y)
assert log_q.shape[0] == labels.shape[0], 'Q and y must come from the same day'
steps = np.arange(log_q.shape[0])
is_event = labels != 0

fig, ax = plt.subplots()
ax.plot(steps, log_q, color='b', linewidth=0.8, label='log F statistic')
ax.scatter(steps[is_event], log_q[is_event], color='r', s=12, zorder=3, label='label != 0')

for level, colour, name in ((lcl, 'r', 'LCL'), (cl, 'orange', 'center line'), (ucl, 'r', 'UCL')):
    # log is undefined for non-positive limits; skip instead of drawing at -inf/NaN.
    if level > 0:
        # xmin/xmax are axes fractions: the line spans the middle 80% of the plot width.
        ax.axhline(y=np.log(level), xmin=0.1, xmax=0.9, color=colour, linestyle='--', label=name)

ax.set(
    title="F chart (0.25 split) on one test day",
    xlabel="Observation index within day",
    ylabel="log(F statistic)",
)
ax.legend()
plt.show()