In [1]:
from fastbook import *
from fastai.tabular.all import *
import matplotlib.ticker as tick
from treeinterpreter import treeinterpreter
from datetime import timedelta  
import ipywidgets as widgets
from ipywidgets import interact, interact_manual, VBox, HBox

In [None]:
!pip install voila
!jupyter serverextension enable --sys-prefix voila 

In [6]:
class rf_predict():
    """Widget app that forecasts daily sales for one store with a pickled random forest.

    Set-up order (each step relies on the previous ones):
    ``load_csv`` -> ``prep_empty`` -> ``load_widgets`` -> ``addHolidays`` -> ``get_info``.
    Clicking *Predict* runs ``on_click_classify``. That builds features (``prep_test``),
    predicts (``predict``) and renders the charts (``interactive_pred``).
    """

    # Columns (and their order) handed to the model.
    # NOTE(review): presumably this must match the training feature order exactly — confirm.
    FEATURES = ['CompetitionDistance', 'Store', 'CompetitionOpenSince', 'BeforePromo',
                'CompetitionOpenSinceMonth', 'AfterPromo', 'StoreType',
                'CompetitionDaysOpen', 'Promo2SinceYear', 'Promo2SinceWeek',
                'Assortment', 'DayOfWeek', 'PromoInterval', 'Dayofyear',
                'AfterStateHoliday', 'Promo_fw', 'Elapsed', 'AfterSchoolHoliday', 'Day',
                'BeforeStateHoliday', 'BeforeSchoolHoliday', 'Promo_bw']

    # 0/1 event flags used to derive elapsed-days and 7-day rolling-count features.
    EVENT_COLUMNS = ['SchoolHoliday', 'StateHoliday', 'Promo']

    def __init__(self):
        # Model and data, filled by load_csv / prep_empty / prep_test.
        self.rf = None
        self.empty = None
        self.test = None
        self.store = None

        self.state_holidays = None
        self.school_holidays = None

        # Outputs of predict().
        self.pred = None
        self.contributions = None

        # Widgets, created by load_widgets().
        self.output_box = None
        self.predict_btn = None
        self.store_dropdown = None
        self.promo_multiselect = None
        self.promo_multiselect_HBox = None
        self.opened_multiselect = None
        self.opened_multiselect_HBox = None

    def load_widgets(self):
        """Create the store dropdown, the promo/open date multi-selects and the Predict button.

        Must run after ``prep_empty``, because the date options are read from
        ``self.empty``'s Year/Month/Day columns.
        """
        self.output_box = widgets.Output()
        self.predict_btn = widgets.Button(description="Predict", disabled=False)

        # Store ids 1..1115 (hard-coded range).
        self.store_dropdown = widgets.Dropdown(
            options=list(range(1, 1116)),
            value=1,
            description='Store Number: ',
            disabled=False,
            style=dict(description_width='initial'))

        # One option per row of self.empty: label "day/month/year", value = row position.
        date_options = [(f'{d}/{m}/{y}', i) for i, (y, m, d) in
                        enumerate(zip(self.empty.Year.values, self.empty.Month.values,
                                      self.empty.Day.values))]

        self.promo_multiselect, self.promo_multiselect_HBox = self._labelled_multiselect(
            'Select Promo Dates', date_options)
        self.opened_multiselect, self.opened_multiselect_HBox = self._labelled_multiselect(
            'Select Open Dates', date_options)

    @staticmethod
    def _labelled_multiselect(title, options):
        """Return a SelectMultiple over ``options`` and an HBox pairing it with its labels."""
        select = widgets.SelectMultiple(
            options=options,
            disabled=False,
            style=dict(description_width='initial'))
        # Shift only selects a contiguous range; Ctrl/Cmd is needed to pick individual dates.
        labels = VBox([widgets.Label(title),
                       widgets.Label('Hold Shift or Ctrl/Cmd to select multiple')])
        return select, HBox([labels, select])

    def load_csv(self):
        """Load the pickled model and the CSV inputs from the current working directory."""
        import pickle
        # SECURITY: pickle.load can execute arbitrary code; only load a trusted rf_final.sav.
        with open('rf_final.sav', 'rb') as f:
            self.rf = pickle.load(f)
        self.empty = pd.read_csv('empty.csv')
        self.state_holidays = pd.read_csv('state_holidays.csv')
        self.school_holidays = pd.read_csv('school_holidays.csv')
        self.store = pd.read_csv('store.csv')

    def prep_empty(self):
        """Give each row of ``self.empty`` a consecutive date starting tomorrow, then add date parts."""
        from datetime import date, timedelta
        today = date.today()
        # One date per row of empty.csv (the original hard-coded 30 days).
        self.empty['Date'] = [today + timedelta(days=k) for k in range(1, len(self.empty) + 1)]
        add_datepart(self.empty, "Date", drop=False)
        self.empty = self.empty.rename(columns={'Dayofweek': 'DayOfWeek'})

    def addHolidays(self):
        """Add 0/1 ``StateHoliday`` and ``SchoolHoliday`` columns to ``self.empty``.

        A row is flagged when its (Month, Day) pair appears in the matching holiday table.
        Any year in the holiday tables is ignored.
        """
        def flags(holidays):
            holiday_days = set(zip(holidays['Month'], holidays['Day']))
            return [int((m, d) in holiday_days)
                    for m, d in zip(self.empty.Month.values, self.empty.Day.values)]

        # Item assignment (not attribute assignment), so the columns are created if missing.
        self.empty['StateHoliday'] = flags(self.state_holidays)
        self.empty['SchoolHoliday'] = flags(self.school_holidays)

    def on_click_classify(self, change):
        """Predict button handler: read the widget state, predict, and render the charts.

        ``change`` is the widget passed by ``Button.on_click`` and is unused.
        """
        self.output_box.clear_output()
        # Everything runs inside the output context so that any exception's traceback is
        # shown in the app, instead of leaving "Please wait..." on screen forever.
        with self.output_box:
            display("Please wait...")
            n_days = len(self.empty)
            promo_days = set(self.promo_multiselect.value)
            open_days = set(self.opened_multiselect.value)
            self.empty['Store'] = self.store_dropdown.value
            self.empty['Promo'] = [int(i in promo_days) for i in range(n_days)]
            self.empty['Open'] = [int(i in open_days) for i in range(n_days)]
            self.prep_test()
            self.predict()
            self.output_box.clear_output()
            self.interactive_pred()

    def interactive_pred(self):
        """Show the predicted-sales line chart and a date dropdown that drives a waterfall chart.

        Intended to be called inside an output context (as ``on_click_classify`` does).
        """
        fig, ax = plt.subplots(figsize=(15, 7))
        ax.set_ylabel('Sales')
        ax.set_xlabel('Date')
        ax.set_title('Predicted Sales')
        ax.plot(self.empty.Date, self.pred.round())
        # Render now. Otherwise the open figure is flushed into interact's own output area,
        # which interact clears the next time the selected date changes.
        plt.show()

        date_options = {d: i for i, d in enumerate(self.empty.Date.dt.date.values)}

        @interact(date=date_options)
        def show_waterfall(date):
            self.plot(self.test.columns, self.contributions[date], threshold=0.3,
                      rotation_value=60, formatting='{:,.2f}')
            # Show explicitly and return None. Returning self.plot's result (the pyplot
            # module) would make interact also display the module's repr.
            plt.show()

    def get_info(self):
        """Wire up the Predict button and display the whole form."""
        # Button.on_click appends a handler, so calling get_info twice would predict twice per click.
        self.predict_btn.on_click(self.on_click_classify)
        display(VBox([self.store_dropdown, self.promo_multiselect_HBox,
                      self.opened_multiselect_HBox, self.predict_btn, self.output_box]))

    def prep_test(self):
        """Build ``self.test``: the model's feature frame for the rows of ``self.empty``.

        Joins the store metadata, derives competition-age, elapsed-event and 7-day
        rolling-event features, encodes with fastai procs, and keeps ``FEATURES``.
        """
        self.test = self.empty.merge(self.store, how='left', on='Store', suffixes=('', '_y'))

        # Missing start dates get sentinels (year 1900, month/week 1).
        self.test['CompetitionOpenSinceYear'] = self.test.CompetitionOpenSinceYear.fillna(1900).astype(np.int32)
        self.test['CompetitionOpenSinceMonth'] = self.test.CompetitionOpenSinceMonth.fillna(1).astype(np.int32)
        self.test['Promo2SinceYear'] = self.test.Promo2SinceYear.fillna(1900).astype(np.int32)
        self.test['Promo2SinceWeek'] = self.test.Promo2SinceWeek.fillna(1).astype(np.int32)
        self.test['CompetitionOpenSince'] = pd.to_datetime(dict(year=self.test.CompetitionOpenSinceYear,
                                                                month=self.test.CompetitionOpenSinceMonth,
                                                                day=15))
        self.test['CompetitionDaysOpen'] = self.test.Date.subtract(self.test.CompetitionOpenSince).dt.days
        self.test.loc[self.test.CompetitionDaysOpen < 0, 'CompetitionDaysOpen'] = 0
        # This also zeroes the 1900 sentinel rows.
        self.test.loc[self.test.CompetitionOpenSinceYear < 1990, 'CompetitionDaysOpen'] = 0
        self.test['CompetitionMonthsOpen'] = self.test['CompetitionDaysOpen'] // 30
        self.test.loc[self.test.CompetitionMonthsOpen > 24, 'CompetitionMonthsOpen'] = 24

        columns = list(self.EVENT_COLUMNS)
        # Ascending dates -> days since the last event; descending -> offset to the next event.
        for fld in columns:
            self.test = self.test.sort_values(['Store', 'Date'])
            self.get_elapsed(fld, 'After')
            self.test = self.test.sort_values(['Store', 'Date'], ascending=[True, False])
            self.get_elapsed(fld, 'Before')

        self.test = self.test.set_index('Date')
        for prefix in ('Before', 'After'):
            for col in columns:
                name = prefix + col
                self.test[name] = self.test[name].fillna(0).astype(int)

        # Per-store 7-row rolling event counts, in both time directions. Selecting `columns`
        # after groupby keeps 'Store' out of the values on every pandas version.
        backward = (self.test.sort_index().groupby('Store')[columns]
                    .rolling(7, min_periods=1).sum().reset_index())
        forward = (self.test.sort_index(ascending=False).groupby('Store')[columns]
                   .rolling(7, min_periods=1).sum().reset_index())
        self.test = self.test.reset_index()

        self.test = self.test.merge(backward, how='left', on=['Date', 'Store'], suffixes=('', '_bw'))
        self.test = self.test.merge(forward, how='left', on=['Date', 'Store'], suffixes=('', '_fw'))
        self.test = self.test.drop(columns=columns)

        self.test['Week'] = self.test.Week.astype('int32')

        contin_vars, cat_vars = cont_cat_split(self.test, max_card=9000)
        # NOTE(review): Categorify/FillMissing are fitted on this prediction frame, not reused
        # from training, so category codes may not match what the model was trained on.
        # Consider saving the training TabularPandas procs and applying them here — confirm.
        processed = TabularPandas(self.test, [Categorify, FillMissing], cat_vars, contin_vars)
        self.test = processed[self.FEATURES]

    def predict(self):
        """Run the forest through treeinterpreter and store predictions and contributions.

        ``self.pred`` is ``np.exp`` of the raw prediction. NOTE(review): this implies the model
        was trained on log sales — confirm. ``self.contributions`` holds the per-row,
        per-feature contributions as returned (not exponentiated).
        """
        prediction, bias, contributions = treeinterpreter.predict(self.rf, self.test)
        self.pred = np.exp(prediction)
        self.contributions = contributions

    def get_elapsed(self, fld, pre):
        """Add column ``pre + fld`` with the days from the last row (same store) where ``fld`` was truthy.

        Rows are visited in the frame's current order, so the caller's sort sets the direction.
        Ascending dates give days since the last event. Descending dates give a non-positive
        offset to the next event. The value is NaN until an event has been seen for the store.
        """
        one_day = np.timedelta64(1, 'D')
        last_date = np.datetime64()  # NaT until an event is seen
        last_store = 0  # assumes no real store id is 0
        elapsed = []

        for store, flag, day in zip(self.test.Store.values, self.test[fld].values, self.test.Date.values):
            if store != last_store:
                # New store: forget the previous store's last event.
                last_date = np.datetime64()
                last_store = store
            if flag:
                last_date = day
            elapsed.append((day - last_date).astype('timedelta64[D]') / one_day)
        self.test[pre + fld] = elapsed

    def plot(self, index, data, Title="", x_lab="", y_lab="",
             formatting="{:,.1f}", green_color='#29EA38', red_color='#FB3C62', blue_color='#24CAFF',
             sorted_value=False, threshold=None, other_label='other', net_label='net',
             rotation_value=30):
        '''
        Given two sequences ordered appropriately, generate a standard waterfall chart.

        Optionally modify the title, axis labels, number formatting, bar colors,
        increment sorting, and thresholding. Thresholding groups lower-magnitude changes
        into a combined group, shown as a single bar on the chart.

        index: one label per entry of ``data``.
        data: signed step sizes.
        sorted_value: if true, order steps by absolute size, largest first.
        threshold: fraction of the largest |step|. Smaller steps are merged into ``other_label``.
        Returns the ``matplotlib.pyplot`` module, with the new figure as the current figure.
        '''
        index = np.array(index)
        data = np.array(data)

        if sorted_value:
            order = np.argsort(abs(data))[::-1]
            data = data[order]
            index = index[order]

        if threshold:
            abs_data = abs(data)
            threshold_v = abs_data.max() * threshold
            if threshold_v > abs_data.min():
                big = abs_data >= threshold_v
                index = np.append(index[big], other_label)
                data = np.append(data[big], sum(data[abs_data < threshold_v]))

        formatter = tick.FuncFormatter(lambda x, pos: formatting.format(x))

        fig, ax = plt.subplots(figsize=(15, 7))
        ax.yaxis.set_major_formatter(formatter)

        trans = pd.DataFrame(data={'amount': data}, index=index)
        # Invisible bar under each step: the running total before that step.
        blank = trans.amount.cumsum().shift(1).fillna(0)

        # Final 'net' bar: the total, drawn up from zero.
        total = trans.amount.sum()
        trans.loc[net_label, 'amount'] = total
        blank.loc[net_label] = 0

        # Steps are green when positive and red otherwise. The net bar is always blue
        # (previously it turned green/red when the total was exactly 1 or 0).
        colors = [green_color if amount > 0 else red_color for amount in trans.amount.iloc[:-1]]
        colors.append(blue_color)

        positions = range(len(trans))
        plt.bar(positions, blank, width=0.5, color='white')
        plt.bar(positions, trans.amount, width=0.6, bottom=blank, color=colors)

        plt.xlabel("\n" + x_lab)
        plt.ylabel(y_lab + "\n")

        # Running total before each bar (including before 'net'); used to place value labels.
        y_height = trans.amount.cumsum().shift(1).fillna(0)

        # Running total after each step. The net entry keeps the total itself.
        running = list(trans.amount)
        for i in range(1, len(running) - 1):
            running[i] += running[i - 1]
        running = pd.Series(running)

        plot_max = running.max()
        plot_min = running.min()
        # Keep zero in view, so the chart does not zoom in on just the changes.
        if (running >= 0).all():
            plot_min = 0
        if (running < 0).all():
            plot_max = 0

        maxmax = max(abs(plot_max), abs(plot_min))
        pos_offset = maxmax / 40
        plot_offset = maxmax / 15

        last = len(trans) - 1
        for pos, amount in enumerate(trans.amount):
            # Label at the bar's top: the running total after the step. For the net bar,
            # the top is the running total before it (the total), so it must not be added again.
            # The net bar is identified by position, not by value, so a step that happens to
            # equal the total is labelled correctly.
            y = y_height.iloc[pos] if pos == last else y_height.iloc[pos] + amount
            if amount > 0:
                y += pos_offset * 2
                label_color = 'g'
            else:
                y -= pos_offset * 4
                label_color = 'r'
            plt.annotate(formatting.format(amount), (pos, y), ha="center", color=label_color, fontsize=9)

        # Scale up the y axis so there is room for the labels.
        plt.ylim(plot_min - round(3.6 * plot_offset, 7), plot_max + round(3.6 * plot_offset, 7))

        plt.xticks(range(len(trans)), trans.index, rotation=rotation_value)

        plt.axhline(0, color='black', linewidth=0.6, linestyle="dashed")
        plt.title(Title)
        plt.tight_layout()

        return plt

In [7]:
# Assemble and show the app. Order matters: the widgets read their date options from
# `empty`, and the holiday flags need the Month/Day columns that prep_empty adds.
rf1 = rf_predict()
for setup_step in (rf1.load_csv, rf1.prep_empty, rf1.load_widgets, rf1.addHolidays, rf1.get_info):
    setup_step()

VBox(children=(Dropdown(description='Store Number: ', options=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, …