In [49]:
# Import
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
from datetime import datetime
    
from datetime import timedelta

from ipywidgets import *
from IPython.display import display

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import pickle # Saving of stan models
import pystan

# Data
# Four years (209 weeks) of weekly records: sales, media impressions and
# media spending.
df = pd.read_csv(filepath_or_buffer='data mmm.csv')

In [50]:
# 1. Media variables
# weekly media spending, one 'mdsp_*' column per channel
mdsp_cols = list(filter(lambda name: 'mdsp_' in name, df.columns))

# 2. Control variables
# NOTE(review): the selections below match the token anywhere in the column
# name, not only as a prefix.
# macro-economic variables
me_cols = list(filter(lambda name: 'me_' in name, df.columns))
# store count
st_cols = ['st_ct']
# markdown / discount variables
mrkdn_cols = list(filter(lambda name: 'mrkdn_' in name, df.columns))
# holiday indicator variables
hldy_cols = list(filter(lambda name: 'hldy_' in name, df.columns))
# seasonality variables
seas_cols = list(filter(lambda name: 'seas_' in name, df.columns))
base_vars = [*me_cols, *st_cols, *mrkdn_cols, *hldy_cols, *seas_cols]

# 3. Target variable
sales_cols = ['sales']

In [51]:
# Replace the dataset's seasonality variables with custom 0/1 dummies:
#   - seas_prd_1 .. seas_prd_10  : calendar month of the week start (Jan-Oct);
#   - seas_week_42 .. seas_week_53: ISO week number of the week start.
# NOTE: ISO week 42 begins in mid-October, so the October month dummy and the
# first week dummies overlap; November/December are covered by week dummies only.
df = df.drop(seas_cols, axis=1)

# Parse the week-start dates once instead of once per dummy column.
_week_start = pd.to_datetime(df['wk_strt_dt'].values)
_iso_week = _week_start.isocalendar()['week'].values

for i in range(1, 11):  # January .. October
    df[f'seas_prd_{i}'] = (_week_start.month == i).astype(int)

for i in range(42, 54):  # ISO weeks 42 .. 53 (mid-October .. end of December)
    df[f'seas_week_{i}'] = (_iso_week == i).astype(int)

seas_cols = [col for col in df.columns if 'seas_' in col]
base_vars = me_cols+st_cols+mrkdn_cols+hldy_cols+seas_cols

In [52]:
# 1.1 Adstock
def apply_adstock(x, L, P, D):
    '''
    Normalised weighted moving sum of a media series (adstock).

    Output value t is  sum_{l=0}^{L-1} w_l * x[t-l] / sum_l w_l  with
    w_l = D ** ((l - P) ** 2); values before the start of `x` count as 0.

    params:
    x: original media variable, array-like (flattened by np.append)
    L: length of the lag window, in periods
    P: peak, lag with the largest weight (for 0 < D < 1)
    D: decay, retain rate
    returns:
    array of the same length as x, adstocked media variable
    '''
    # Weights ordered oldest -> newest so they line up with the window
    # padded[t-L+1 .. t]; the weight for lag `lag` is D ** ((lag - P) ** 2).
    lag_weights = np.array([D ** ((lag - P) ** 2) for lag in range(L)])[::-1]
    total_weight = sum(lag_weights)

    # Left-pad with L-1 zeros so the first observations get a full window.
    padded = np.append(np.zeros(L - 1), x)
    return np.array([
        sum(padded[end - L + 1:end + 1] * lag_weights) / total_weight
        for end in range(L - 1, len(padded))
    ])

def adstock_transform(df, md_cols, adstock_params):
    '''
    Apply ``apply_adstock`` to each media column.

    The parameters for column `md_col` are looked up under its last
    '_'-separated token (e.g. 'mdsp_sem' -> 'sem').

    params:
    df: original data
    md_cols: list, media variables to be transformed
    adstock_params: dict,
        e.g., {'sem': {'L': 8, 'P': 0, 'D': 0.1}, 'dm': {'L': 4, 'P': 1, 'D': 0.7}}
    returns:
    DataFrame with one adstocked column per entry of md_cols (index 0..n-1)
    '''
    adstocked = pd.DataFrame()
    for md_col in md_cols:
        params = adstock_params[md_col.split('_')[-1]]
        adstocked[md_col] = apply_adstock(
            df[md_col].values, params['L'], params['P'], params['D'])
    return adstocked

In [53]:
 # 2.1 Control Model / Base Sales Model
# Ventes en fonction des variables macro-économiques

# helper functions
from sklearn.metrics import mean_squared_error

def mean_absolute_percentage_error(y_true, y_pred):
    '''
    Mean absolute percentage error, in percent:
    100 * mean(|(y_true - y_pred) / y_true|).

    Zeros in y_true yield inf/nan (numpy division semantics).
    '''
    actual = np.array(y_true)
    predicted = np.array(y_pred)
    relative_error = np.abs((actual - predicted) / actual)
    return np.mean(relative_error) * 100

def apply_mean_center(x):
    '''
    Scale `x` by its mean: returns (x / mean(x), mean(x)).

    Despite the name nothing is subtracted; the scaled series has mean 1
    (when mean(x) != 0), not 0.
    '''
    mu = np.mean(x)
    return x / mu, mu

def mean_center_trandform(df, cols):
    '''
    Mean-scale the given columns with ``apply_mean_center`` (x / mean(x)).

    params:
    df: DataFrame holding `cols`
    cols: list of column names
    returns:
    DataFrame of scaled columns (index 0..n-1)
    scaler, dict {column: mean}
    '''
    scaled_df, sc = pd.DataFrame(), {}
    for col in cols:
        scaled_values, sc[col] = apply_mean_center(df[col].values)
        scaled_df[col] = scaled_values
    return scaled_df, sc

def mean_log1p_trandform(df, cols):
    '''
    Mean-scale each column (x / mean(x)) with ``apply_mean_center``, then
    apply log1p.

    params:
    df: DataFrame holding `cols`
    cols: list of column names
    returns:
    DataFrame of log1p(x / mean(x)) columns (index 0..n-1)
    scaler, dict {column: mean}
    '''
    transformed, sc = pd.DataFrame(), {}
    for col in cols:
        ratio, sc[col] = apply_mean_center(df[col].values)
        transformed[col] = np.log1p(ratio)
    return transformed, sc

In [54]:
# Mean-scale the target and the numeric base variables (x / mean(x)); the
# holiday and seasonality columns are appended unscaled.
df_ctrl, sc_ctrl = mean_center_trandform(df, ['sales'] + me_cols + st_cols + mrkdn_cols)
df_ctrl = pd.concat([df_ctrl, df[hldy_cols + seas_cols]], axis=1)

# Variables assumed to relate positively to sales: every base variable except
# seasonality (macro-economy, store count, markdown/discount, holidays).
pos_vars = list(filter(lambda col: col not in seas_cols, base_vars))
X1 = df_ctrl[pos_vars].to_numpy()

# Variables that may push sales either up or down: seasonality.
pn_vars = seas_cols
X2 = df_ctrl[pn_vars].to_numpy()

In [55]:
# Load the pickled posterior draws of the base-sales (control) model.
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
with open("sm1.pkl", "rb") as f1:
    data_dict = pickle.load(f1)
    fit1_result = data_dict['fit1_result']

# Control (base-sales) model: summarise the posterior draws, then predict.
def extract_ctrl_model(fit_result, pos_vars=pos_vars, pn_vars=pn_vars,
                       extract_param_list=False):
    '''
    Summarise the posterior draws of the control model (Stan fit 1).

    params:
    fit_result: dict of posterior draws (draws along axis 0) with keys
        'beta1', 'beta2' and 'alpha'
    pos_vars: list, regressors whose coefficients are 'beta1'
        (macro-economy, store count, markdown/discount, holidays)
    pn_vars: list, regressors whose coefficients are 'beta2' (seasonality)
    extract_param_list: bool, also keep every draw under '<name>_list'
    returns:
    dict with the variable lists, posterior-mean 'beta1'/'beta2' (lists)
    and posterior-mean 'alpha'
    '''
    ctrl_model = {
        'pos_vars': pos_vars,
        'pn_vars': pn_vars,
        'beta1': fit_result['beta1'].mean(axis=0).tolist(),
        'beta2': fit_result['beta2'].mean(axis=0).tolist(),
        'alpha': fit_result['alpha'].mean(),
    }
    if extract_param_list:
        ctrl_model.update({
            name + '_list': fit_result[name].tolist()
            for name in ('beta1', 'beta2', 'alpha')
        })
    return ctrl_model

def ctrl_model_predict(ctrl_model, df):
    '''
    Linear prediction of the control (base-sales) model.

    params:
    ctrl_model: dict with 'pos_vars', 'pn_vars', 'beta1', 'beta2', 'alpha'
        (as built by ``extract_ctrl_model``)
    df: DataFrame containing every column listed in 'pos_vars' and 'pn_vars'
    returns:
    X_pos . beta1 + X_pn . beta2 + alpha, one value per row of df
    '''
    positive_part = np.dot(df[ctrl_model['pos_vars']], np.array(ctrl_model['beta1']))
    signed_part = np.dot(df[ctrl_model['pn_vars']], np.array(ctrl_model['beta2']))
    return positive_part + signed_part + ctrl_model['alpha']

# Base-sales prediction of the control model, converted from mean-scaled
# units back to sales units.
base_sales_model = extract_ctrl_model(fit_result=fit1_result, pos_vars=pos_vars, pn_vars=pn_vars)
base_sales = ctrl_model_predict(ctrl_model=base_sales_model, df=df_ctrl)
df['base_sales'] = sc_ctrl['sales'] * base_sales

In [56]:
# 2.2 Marketing Mix Model

# adstock_params = {'dm': {'L': 8, 'P': 1, 'D': 0.5},
#                   'inst': {'L': 8, 'P': 1, 'D': 0.5},
#                   'nsp': {'L': 8, 'P': 1, 'D': 0.5},
#                   'auddig': {'L': 8, 'P': 1, 'D': 0.5},
#                   'audtr': {'L': 8, 'P': 1, 'D': 0.5},
#                   'vidtr': {'L': 8, 'P': 1, 'D': 0.5},
#                   'viddig': {'L': 8, 'P': 1, 'D': 0.5},
#                   'so': {'L': 8, 'P': 1, 'D': 0.5},
#                   'on': {'L': 8, 'P': 1, 'D': 0.5},
#                   'sem': {'L': 8, 'P': 1, 'D': 0.5}}

# Target and control for the MMM: log1p(x / mean(x)) of sales and base sales.
df_mmm, sc_mmm = mean_log1p_trandform(df, ['sales', 'base_sales'])
max_lag = 8  # maximum duration of the lag effect, in weeks

# Mean raw spend per media channel over all weeks.
mu_mdsp = df[mdsp_cols].apply(np.mean, axis=0).values

num_media = len(mdsp_cols)
# Media spend of the first 3 years (52 * 3 weeks), left-padded with
# max_lag - 1 zero rows so every week has a full lag window.
X_media = np.vstack([np.zeros((max_lag - 1, num_media)), df[mdsp_cols].values[:52 * 3]])

# Control regressor as a single column.
X_ctrl = df_mmm['base_sales'].values.reshape(-1, 1)

In [57]:
# model_data2 = {
#     'N': 52*3, # number of observations
#     'max_lag': max_lag, 
#     'num_media': num_media, # number of media
#     'X_media': X_media, # Adstock sur les montants
#     'mu_mdsp': mu_mdsp, # Adstock sur les montants
#     'num_ctrl': X_ctrl.shape[1], # 1
#     'X_ctrl': X_ctrl[:52*3], # Prédiction base sales du modèle précédent
#     'y': df_mmm['sales'].values[:52*3] # target
# }

# model_code2 = ''' # Format .stan
# functions {
#   // the adstock transformation with a vector of weights
#   real Adstock(vector t, row_vector weights) {
#     return dot_product(t, weights) / sum(weights);
#   }
# }
# data {
#   // the total number of observations
#   int<lower=1> N;
#   // the vector of sales
#   real y[N];
#   // the maximum duration of lag effect, in weeks
#   int<lower=1> max_lag;
#   // the number of media channels
#   int<lower=1> num_media;
#   // matrix of media variables
#   matrix[N+max_lag-1, num_media] X_media;
#   // vector of media variables' mean
#   real mu_mdsp[num_media];
#   // the number of other control variables
#   int<lower=1> num_ctrl;
#   // a matrix of control variables
#   matrix[N, num_ctrl] X_ctrl;
# }
# parameters { 
#   // residual variance
#   real<lower=0> noise_var;
#   // the intercept
#   real tau;
#   // the coefficients for media variables and base sales
#   vector<lower=0>[num_media+num_ctrl] beta; # Force les coef a être positifs
#   // the decay and peak parameter for the adstock transformation of
#   // each media
#   vector<lower=0,upper=1>[num_media] decay;
#   vector<lower=0,upper=ceil(max_lag/2)>[num_media] peak;
# }
# transformed parameters { 
#   // the cumulative media effect after adstock
#   real cum_effect;
#   // matrix of media variables after adstock
#   matrix[N, num_media] X_media_adstocked;
#   // matrix of all predictors
#   matrix[N, num_media+num_ctrl] X;
  
#   // adstock, mean-center, log1p transformation
#   row_vector[max_lag] lag_weights;
#   for (nn in 1:N) {
#     for (media in 1 : num_media) {
#       for (lag in 1 : max_lag) {
#         lag_weights[max_lag-lag+1] <- pow(decay[media], (lag - 1 - peak[media]) ^ 2);
#       }
#      cum_effect <- Adstock(sub_col(X_media, nn, media, max_lag), lag_weights);
#      X_media_adstocked[nn, media] <- log1p(cum_effect/mu_mdsp[media]);
#     }
#   X <- append_col(X_media_adstocked, X_ctrl);
#   } 
  
# }
# model {
#   decay ~ beta(3,3); 

#   peak ~ uniform(0, ceil(max_lag/2)); 
  
#   tau ~ normal(0, 5); 
  
#   for (i in 1 : num_media+num_ctrl) {
#     beta[i] ~ normal(0, 1); # Autant de coef qu'il y a de variable (5 médias + base_sales)
#   }
  
#   noise_var ~ inv_gamma(0.05, 0.05 * 0.01);
  
#   y ~ normal(tau + X * beta, sqrt(noise_var));
# }
# '''

# sm2 = pystan.StanModel(model_code=model_code2, verbose=True)
# fit2 = sm2.sampling(data=model_data2, iter=1000, chains=3) # 1500 coef beta calculés par variable -> on en fait la moyenne
# fit2_result = fit2.extract()

# # Save
# with open("sm2_new_ad.pkl", "wb") as f:
#     pickle.dump({'fit2_result' : fit2_result}, f, protocol=-1)

In [58]:
# Load the pickled posterior draws of the marketing-mix model.
# NOTE: pickle.load can execute arbitrary code -- only open trusted files.
with open("sm2_new_ad.pkl", "rb") as f2:
    data_dict = pickle.load(f2)
    fit2_result = data_dict['fit2_result']

# Marketing-mix model (Stan fit 2): posterior summary.
def extract_mmm(fit_result, media_vars=mdsp_cols, ctrl_vars=['base_sales'],
                extract_param_list=True):
    '''
    Summarise the posterior draws of the marketing-mix model.

    params:
    fit_result: dict of posterior draws (draws along axis 0) with keys
        'decay', 'peak', 'beta' and 'tau'
    media_vars: list, media spending columns in the order used for the fit
    ctrl_vars: list, control columns placed after the media columns in the fit
    extract_param_list: bool, also keep every draw under '<name>_list'
    returns:
    dict with
        'max_lag'                 : the notebook-level ``max_lag``
        'media_vars', 'ctrl_vars' : copies of the given lists
        'decay', 'peak'           : posterior means, one per media column
        'beta'                    : posterior means, one per media column
                                    followed by one per control column
        'tau'                     : posterior mean of the intercept
        '<name>_list'             : every draw (only if extract_param_list)
        'adstock_params'          : {media suffix: {'L': max_lag, 'P': peak, 'D': decay}}
    '''
    mmm = {}

    mmm['max_lag'] = max_lag
    # Store copies: the defaults are shared objects (the global `mdsp_cols`
    # and a list created once at definition time), so a caller mutating
    # mmm['media_vars'] / mmm['ctrl_vars'] would otherwise silently alter them.
    mmm['media_vars'], mmm['ctrl_vars'] = list(media_vars), list(ctrl_vars)
    mmm['decay'] = decay = fit_result['decay'].mean(axis=0).tolist()
    mmm['peak'] = peak = fit_result['peak'].mean(axis=0).tolist()
    # len(media_vars) + len(ctrl_vars) values: media first, then controls.
    mmm['beta'] = fit_result['beta'].mean(axis=0).tolist()
    mmm['tau'] = fit_result['tau'].mean()
    if extract_param_list:
        mmm['decay_list'] = fit_result['decay'].tolist()
        mmm['peak_list'] = fit_result['peak'].tolist()
        mmm['beta_list'] = fit_result['beta'].tolist()
        mmm['tau_list'] = fit_result['tau'].tolist()

    # Per-media adstock parameters keyed by the media suffix
    # ('mdsp_sem' -> 'sem'), the key format used by `adstock_transform`.
    media_names = [col.replace('mdsp_', '') for col in media_vars]
    mmm['adstock_params'] = {
        name: {'L': max_lag, 'P': peak[i], 'D': decay[i]}
        for i, name in enumerate(media_names)
    }
    return mmm

mmm = extract_mmm(fit_result=fit2_result, media_vars=mdsp_cols,
                  ctrl_vars=['base_sales'])

# Adstock parameters per media suffix, reused by the prediction cells.
adstock_params = mmm['adstock_params']

In [59]:
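# Inspect the fitted adstock parameters (posterior means).
# NOTE: the output recorded below shows keys with the 'mdsp_' prefix, but
# extract_mmm keys adstock_params by the bare media suffix (e.g. 'dm').
# for elt in mmm['adstock_params']:
#     print(elt, mmm['adstock_params'][elt])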
    
# mdsp_dm {'L': 8, 'P': 1.2949086955068048, 'D': 0.48825528989815126}
# mdsp_inst {'L': 8, 'P': 1.6219506581420853, 'D': 0.5024087650459901}
# mdsp_nsp {'L': 8, 'P': 2.188861273525626, 'D': 0.5037726499565891}
# mdsp_auddig {'L': 8, 'P': 2.0983247194396366, 'D': 0.5077734374302987}
# mdsp_audtr {'L': 8, 'P': 2.0823421637864405, 'D': 0.5085701509642458}
# mdsp_vidtr {'L': 8, 'P': 1.5793084049136936, 'D': 0.4971513430012152}
# mdsp_viddig {'L': 8, 'P': 1.5318323015387405, 'D': 0.48975751024792014}
# mdsp_so {'L': 8, 'P': 2.0237726330777392, 'D': 0.5037517887070598}
# mdsp_on {'L': 8, 'P': 1.5325824492724303, 'D': 0.48849198396192917}
# mdsp_sem {'L': 8, 'P': 1.2760761201093542, 'D': 0.4809640162735905}

In [60]:
def get_mu_var(mmm, df, original_sales=df['sales']):
    '''
    Compute the means used to rescale MMM inputs at prediction time.

    params:
    mmm: dict from ``extract_mmm``; reads 'media_vars', 'ctrl_vars' and,
        when present, 'adstock_params'
    df: DataFrame containing the media columns, the control columns and 'sales'
    original_sales: unused; kept so existing calls keep working
    returns:
    dict {column: mean}, inserted in this order: the adstocked media columns
    (media_vars order), the control columns, then 'sales'. Downstream code
    may rely on this order (e.g. positional slicing of mu_var.values()).
    '''
    media_vars, ctrl_vars = mmm['media_vars'], mmm['ctrl_vars']

    # Use the adstock parameters of the model being evaluated; fall back to
    # the notebook-level `adstock_params` for dicts that do not carry them.
    if 'adstock_params' in mmm:
        md_adstock_params = mmm['adstock_params']
    else:
        md_adstock_params = adstock_params

    # Means of the adstocked media series.
    # NOTE(review): the Stan model divides adstocked spend by `mu_mdsp`, the
    # mean of the RAW spend, whereas this uses the mean of the ADSTOCKED
    # spend; the two differ slightly -- confirm which one is intended.
    X_media_adstocked = adstock_transform(df, media_vars, md_adstock_params)
    _, mu = mean_center_trandform(X_media_adstocked, media_vars)

    # Means of the control variables and of the sales target.
    _, mu_ctrl = mean_center_trandform(df[ctrl_vars], ctrl_vars)
    _, mu_sales = mean_center_trandform(df, ['sales'])

    mu.update(mu_ctrl)
    mu.update(mu_sales)
    return mu

# Scaling means of the adstocked media, the control variables and sales.
mu_var = get_mu_var(mmm=mmm, df=df, original_sales=df['sales'])

In [61]:
from dateutil.easter import *

# Week-start dates grouped by calendar month: liste_date_to_mois[m - 1] holds,
# in row order, every 'wk_strt_dt' value of `df` that falls in month m.
# Dates are parsed once and each month is selected with a boolean mask.
_month_of_row = pd.to_datetime(df['wk_strt_dt']).dt.month.values
liste_date_to_mois = [
    list(df['wk_strt_dt'].values[_month_of_row == month])
    for month in range(1, 13)
]

def get_last_sunday(date):
    '''
    Return the most recent Sunday strictly before `date` (a Sunday maps to
    the Sunday one week earlier); None stays None.
    '''
    if date is None:
        return None
    # weekday(): Monday=0 -> go back 1 day ... Sunday=6 -> go back 7 days
    days_since_sunday = date.weekday() + 1
    return date - timedelta(days=days_since_sunday)
    
def get_next_week(date):
    '''Return `date` shifted forward by one week (7 days); None stays None.'''
    return None if date is None else date + timedelta(weeks=1)

def date_to_vac(date_select):
    '''
    List the holiday columns ('hldy_*') to switch on for one week.

    The seven days examined are date_select + 1 day .. date_select + 7 days
    (``get_next_week(date_select)`` and the six days before it). When
    `date_select` is the Sunday returned by ``get_last_sunday``, this is the
    Monday-to-Sunday week that follows it.

    NOTE: moving holidays are detected through one fixed weekday of their
    week (e.g. the Saturday after Thanksgiving, the Sunday after a Monday
    holiday), which assumes the window runs Monday to Sunday.

    params:
    date_select: date/datetime-like (must support ``+ timedelta`` and
        ``pd.to_datetime``), or None
    returns:
    list of 'hldy_*' names; empty when date_select is None
    '''
    liste_vac = []
    if date_select is not None:
        date_select = get_next_week(date_select)
        for jour in range(7):
            date = pd.to_datetime(date_select) - pd.Timedelta(jour, unit='days')

            # Fixed-date holidays
            if date.month==12 and date.day == 25:
                liste_vac.append('hldy_Christmas Day')
            if date.month==12 and date.day == 24:
                liste_vac.append('hldy_Christmas Eve')
            if date.month==12 and date.day == 26:
                liste_vac.append('hldy_Day after Christmas')
            if date.month==7 and date.day == 4:
                liste_vac.append('hldy_July 4th')
            if date.month==1 and date.day == 1:
                liste_vac.append('hldy_New Year\'s Day')
            if date.month==12 and date.day == 31:
                liste_vac.append('hldy_NYE')
            if date.month==11 and date.day == 11:
                liste_vac.append('hldy_Veterans Day')

            # Thanksgiving: FOURTH Thursday of November. (The last Thursday is
            # wrong in years where November has five Thursdays, e.g. 2017,
            # 2018, 2023.) Matched through the Saturday of its week.
            premier_novembre = datetime(date.year, 11, 1)
            thanksgiving = premier_novembre + timedelta(
                days=(3 - premier_novembre.weekday()) % 7 + 21)
            if (date - pd.Timedelta(2, unit='days')) == thanksgiving:
                liste_vac.append('hldy_Thanksgiving')
                liste_vac.append('hldy_Pre Thanksgiving')
                liste_vac.append('hldy_Black Friday') # day after Thanksgiving

            # Cyber Monday (Monday after Thanksgiving): matched through the
            # Saturday of the following week.
            if (date - pd.Timedelta(2, unit='days')) == (thanksgiving + pd.Timedelta(7, unit='days')):
                liste_vac.append('hldy_Cyber Monday')

            # Memorial Day: last Monday of May (matched through the Wednesday).
            decalage_lundi = datetime(date.year, 5, 31).weekday()
            if decalage_lundi == 0:
                dernier_lundi = datetime(date.year, 5, 31)
            if decalage_lundi > 0:
                dernier_lundi = datetime(date.year, 5, 31) - pd.Timedelta(decalage_lundi, unit='days')
            if (date - pd.Timedelta(2, unit='days')) == dernier_lundi:
                liste_vac.append('hldy_Memorial Day')

            if date.month==2 and date.day == 14:
                liste_vac.append('hldy_Valentine\'s Day')

            # Columbus Day: 2nd Monday of October (matched through the Sunday).
            decalage_lundi = 7 - datetime(date.year, 10, 1).weekday()
            if decalage_lundi == 7:
                deuxieme_lundi = datetime(date.year, 10, 1) + pd.Timedelta(7, unit='days')
            if decalage_lundi < 7:
                deuxieme_lundi = datetime(date.year, 10, 1) + pd.Timedelta(decalage_lundi+7, unit='days')
            if (date - pd.Timedelta(6, unit='days')) == deuxieme_lundi:
                liste_vac.append('hldy_Columbus Day')

            # Green Monday: 2nd Monday of December (matched through the Sunday).
            decalage_lundi = 7 - datetime(date.year, 12, 1).weekday()
            if decalage_lundi == 7:
                deuxieme_lundi = datetime(date.year, 12, 1) + pd.Timedelta(7, unit='days')
            if decalage_lundi < 7:
                deuxieme_lundi = datetime(date.year, 12, 1) + pd.Timedelta(decalage_lundi+7, unit='days')
            if (date - pd.Timedelta(6, unit='days')) == deuxieme_lundi:
                liste_vac.append('hldy_Green Monday')

            # Labor Day: 1st Monday of September (matched through the Sunday).
            decalage_lundi = 7 - datetime(date.year, 9, 1).weekday()
            if decalage_lundi == 7:
                premier_lundi = datetime(date.year, 9, 1)
            if decalage_lundi < 7:
                premier_lundi = datetime(date.year, 9, 1) + pd.Timedelta(decalage_lundi, unit='days')
            if (date - pd.Timedelta(6, unit='days')) == premier_lundi:
                liste_vac.append('hldy_Labor Day')

            # Presidents Day: 3rd Monday of February (matched through the Sunday).
            decalage_lundi = 7 - datetime(date.year, 2, 1).weekday()
            if decalage_lundi == 7:
                troisieme_lundi = datetime(date.year, 2, 1) + pd.Timedelta(14, unit='days')
            if decalage_lundi < 7:
                troisieme_lundi = datetime(date.year, 2, 1) + pd.Timedelta(decalage_lundi+14, unit='days')
            if (date - pd.Timedelta(6, unit='days')) == troisieme_lundi:
                liste_vac.append('hldy_Presidents Day')

            # MLK Day: 3rd Monday of January (matched through the Sunday).
            decalage_lundi = 7 - datetime(date.year, 1, 1).weekday()
            if decalage_lundi == 7:
                troisieme_lundi = datetime(date.year, 1, 1) + pd.Timedelta(14, unit='days')
            if decalage_lundi < 7:
                troisieme_lundi = datetime(date.year, 1, 1) + pd.Timedelta(decalage_lundi+14, unit='days')
            if (date - pd.Timedelta(6, unit='days')) == troisieme_lundi:
                liste_vac.append('hldy_MLK')

            # Father's Day: 3rd Sunday of June.
            decalage_dimanche = 6 - datetime(date.year, 6, 1).weekday()
            if decalage_dimanche == 0:
                troisieme_dimanche = datetime(date.year, 6, 1) + pd.Timedelta(14, unit='days')
            if decalage_dimanche > 0:
                troisieme_dimanche = datetime(date.year, 6, 1) + pd.Timedelta(decalage_dimanche+14, unit='days')
            if date == troisieme_dimanche:
                liste_vac.append('hldy_Father\'s Day')

            # Mother's Day: 2nd Sunday of May.
            decalage_dimanche = 6 - datetime(date.year, 5, 1).weekday()
            if decalage_dimanche == 0:
                second_dimanche = datetime(date.year, 5, 1) + pd.Timedelta(7, unit='days')
            if decalage_dimanche > 0:
                second_dimanche = datetime(date.year, 5, 1) + pd.Timedelta(decalage_dimanche+7, unit='days')
            if date == second_dimanche:
                liste_vac.append('hldy_Mother\'s Day')

            # Easter Sunday. Compare calendar dates: easter() returns a
            # datetime.date, and pandas >= 2.0 never treats a Timestamp as
            # equal to a datetime.date.
            if date.date() == easter(date.year):
                liste_vac.append('hldy_Easter')

    return liste_vac
    
def date_to_sea(date_select):
    '''
    List the seasonality columns ('seas_*') to switch on for a date.

    - 'seas_prd_<month>' for months 1-10 (no month column for Nov/Dec);
    - 'seas_week_<ISO week>' for ISO weeks 42 and later.

    params:
    date_select: anything accepted by ``pd.to_datetime``, or None
    returns:
    list of column names, month column first; empty when date_select is None
    '''
    if date_select is None:
        return []
    timestamp = pd.to_datetime(date_select)
    month_dummies = [f'seas_prd_{timestamp.month}'] if timestamp.month < 11 else []
    iso_week = timestamp.isocalendar()[1]
    week_dummies = [f'seas_week_{iso_week}'] if iso_week > 41 else []
    return month_dummies + week_dummies

In [62]:
# French display labels for each media suffix (the part after 'mdsp_').
# NOTE(review): some labels may not match the abbreviations ('dm' is often
# direct mail; 'auddig'/'viddig' suggest digital audio/video while
# 'audtr'/'vidtr' suggest traditional radio/TV) -- confirm with the data
# dictionary.
dict_trad_media = {
    'dm': 'Email',
    'inst': 'Encart publicitaire',
    'nsp': 'Journaux',
    'auddig': 'Radio',
    'audtr': 'Audio',
    'vidtr': 'Vidéo',
    'viddig': 'Télévision',
    'so': 'Réseaux sociaux',
    'on': 'En ligne',
    'sem': 'Moteur de recherche',
}

# Global state shared by the prediction functions below.
# Posterior-mean MMM coefficients and intercept.
beta = mmm['beta']
tau = mmm['tau']

media_vars = mdsp_cols
ctrl_vars = ['base_sales']
num_media = len(media_vars)
num_ctrl = len(ctrl_vars)

# Rebuild the control design matrix, then pin the listed mean-scaled numeric
# controls to 1, i.e. to their historical mean level.
df_ctrl, sc_ctrl = mean_center_trandform(df, ['sales'] + me_cols + st_cols + mrkdn_cols)
df_ctrl = pd.concat([df_ctrl, df[hldy_cols + seas_cols]], axis=1).assign(
    me_ics_all=1,
    me_gas_dpg=1,
    st_ct=1,
    mrkdn_valadd_edw=1,
    mrkdn_pdm=1,
)

# Adstock parameters per media suffix, split into three lookup dicts.
media_names = [col.replace('mdsp_', '') for col in media_vars]

L, P, D = {}, {}, {}
for md in media_names:
    L[md], P[md], D[md] = (adstock_params[md][key] for key in ('L', 'P', 'D'))

In [67]:
# Sales prediction for the media-planning widgets. Both entry points share
# `_predict_sales_for_week`; they differ only in the media amounts they pass.
# NOTE(review): get_pred_sales_individuel uses every media amount, while
# get_pred_sales_all keeps a single medium. The UI labels ('1. Comparaison des
# médias' = whole budget in one medium, '2. Tous les médias' = mix of media)
# look swapped relative to this behaviour -- confirm against choix_media.
def _predict_sales_for_week(date, media_amounts):
    '''
    Predict sales (original units) for the week of `date`, given one
    adstocked amount per media channel.

    Side effect: rewrites the holiday and seasonality columns of the global
    ``df_ctrl`` for the selected week.

    params:
    date: week passed to ``date_to_vac`` / ``date_to_sea`` (may be None)
    media_amounts: sequence indexed like ``media_vars``; each amount is
        divided by that channel's entry in ``mu_var``
    returns:
    predicted sales for the last row of the design matrix
    '''
    # Holiday columns: reset, then switch on the holidays of the week.
    df_ctrl[hldy_cols] = 0
    for vac in date_to_vac(date):
        df_ctrl[vac] = 1

    # Seasonality columns: reset, then switch on the month / ISO-week columns.
    df_ctrl[seas_cols] = 0
    for sea in date_to_sea(date):
        df_ctrl[sea] = 1

    # Control-model output (mean-scaled sales) -> MMM control input
    # base_sales / mean(base_sales) + 1.
    baseline = ctrl_model_predict(base_sales_model, df_ctrl)
    baseline = (baseline*df['sales'].mean())/mu_var['base_sales']
    baseline = baseline + 1

    # One column per media channel, every row holding the requested amount.
    data = pd.DataFrame(np.zeros((len(df), num_media)), columns=media_vars)
    for i, md in enumerate(media_vars):
        data[md] = media_amounts[i]

    # Scale each channel by its own mean, looked up by name rather than by
    # position in `mu_var` with a hard-coded channel count.
    X_media = data / [mu_var[md] for md in media_vars] + 1

    # Multiplicative form of the log-log model:
    # exp(tau + sum_k beta_k * log(X_k)) = exp(tau) * prod_k X_k ** beta_k
    factor_df = X_media ** np.array(beta[:num_media])
    factor_df['base_sales'] = baseline ** beta[num_media]
    factor_df['intercept'] = np.exp(tau)

    y_pred = factor_df.prod(axis=1)
    # Invert log1p(sales / mean(sales)): sales = (y_pred - 1) * mean(sales).
    return ((y_pred-1)*mu_var['sales']).iloc[-1]


def get_pred_sales_individuel(date, liste_adstock):
    '''
    Predict sales for the week of `date` when every media channel receives
    its own amount from `liste_adstock`.

    params:
    date: week to predict (see ``date_to_vac`` / ``date_to_sea``)
    liste_adstock: one adstocked amount per media channel, in ``mdsp_cols`` order
    returns:
    predicted sales (original units)
    '''
    return _predict_sales_for_week(date, liste_adstock)


def get_pred_sales_all(date, liste_adstock, indice):
    '''
    Predict sales for the week of `date` when only media channel `indice`
    is active (amount ``liste_adstock[indice]``); all other channels get 0.

    params:
    date: week to predict (see ``date_to_vac`` / ``date_to_sea``)
    liste_adstock: adstocked amounts per media channel, in ``mdsp_cols`` order
    indice: position of the single active channel
    returns:
    predicted sales (original units)
    '''
    amounts = [liste_adstock[i] if i == indice else 0 for i in range(num_media)]
    return _predict_sales_for_week(date, amounts)

def _signaler_dates_desordonnees(liste_date):
    '''Print a red error line for each date that is earlier than the date before it.

    This is only a warning: the caller keeps running the simulation afterwards.
    '''
    for i, (date_prec, date_suiv) in enumerate(zip(liste_date, liste_date[1:])):
        if date_prec > date_suiv:
            print(f'\033[91m\033[1mLa date n°{i+2} ({date_suiv}) doit être supérieure à la date n°{i+1} ({date_prec})\033[0m')


def _semaines_historique(liste_date, nb_semaines_initial=8):
    '''Compute the week layout of the simulated horizon.

    Returns (ecarts, nb_semaines):
        - ecarts[i] : whole weeks between date i-1 and date i (0 for the first date,
          negative when the dates are out of order)
        - nb_semaines[i] : weeks simulated in outer iteration i of the history loop:
          `nb_semaines_initial` for the first date, then ecarts[i], or 0 when this gap
          or the previous one is negative
    '''
    ecarts = [0] + [int((suiv - prec).days / 7) for prec, suiv in zip(liste_date, liste_date[1:])]
    nb_semaines = [nb_semaines_initial]
    for i in range(1, len(liste_date)):
        nb_semaines.append(0 if (ecarts[i] < 0 or ecarts[i-1] < 0) else ecarts[i])
    return ecarts, nb_semaines


def _construire_historique(montants_par_date, liste_date, ecarts, nb_total):
    '''Build, for each media, its weekly amount over the simulated horizon.

    montants_par_date[k][j] is the amount put on media j at date k. The first date's
    amounts sit at week 0. The amounts of date m (m >= 1) are added at week index
    sum(ecarts[:m+1]), its offset in weeks from the first date. A date equal to
    the previous one is ignored.
    '''
    historique = [[montant] + [0] * (nb_total - 1) for montant in montants_par_date[0]]
    for k in range(len(liste_date) - 1):
        if liste_date[k] != liste_date[k+1]:
            position = sum(ecarts[:k+2])
            for j, montant in enumerate(montants_par_date[k+1]):
                historique[j][position] += montant
    return historique


def _afficher_entete_semaine(indice_semaine, date_semaine, liste_date):
    '''Print the header of one simulated week: number, date, investment markers,
    holidays and seasonality that apply to it.'''
    print(f'\nSemaine {indice_semaine} : {date_semaine}')

    date_depart = liste_date[0]
    if date_depart == date_semaine:  # first investment starts this week
        print(f'\033[92m\033[1mSemaine du 1er investissement\033[0m')
    for i in range(1, len(liste_date)):
        # Later investment starts this week (skipped when it duplicates an earlier date)
        if (date_semaine == liste_date[i]) and (date_depart != liste_date[i]) and (liste_date[i-1] != liste_date[i]):
            print(f'\033[92m\033[1mSemaine du {i+1}ème investissement\033[0m')

    vacances = date_to_vac(date_semaine)
    if vacances != []:
        print(f' - {vacances}')

    saisons = date_to_sea(date_semaine)
    if saisons != []:
        print(f' - {saisons}')


def _afficher_totaux(titre, donnees, colonne):
    '''Print `titre`, then the sum of `colonne` per 'Média', largest first.'''
    print(titre)
    totaux = donnees[['Média', colonne]].groupby('Média').sum().sort_values(by=[colonne], ascending=False)
    for md, vt in zip(totaux.index, totaux.values):
        print(f'{md} : {round(vt[0]):,}')


def _afficher_resultats(ventes_df, gains_df):
    '''Plot the weekly sales and the weekly sales gains, each followed by its totals.'''
    config = dict({'scrollZoom': True})

    fig = px.line(ventes_df, x='Semaine', y='Vente', color='Média',
                  title='Évolution des ventes',
                  hover_name="Média", hover_data={"Date": True,
                                                  "Semaine": True,
                                                  "Vente": True,
                                                  "Média": False},
                  width=700, height=500)
    fig.update_traces(mode="markers+lines")
    fig.update_layout(xaxis_title='Semaine', yaxis_title="Vente")
    fig.show(config=config)

    _afficher_totaux('\033[92m\033[4m\033[1mSomme des ventes :\033[0m', ventes_df, 'Vente')

    fig = px.line(gains_df, x='Semaine', y='Gain de vente', color='Média',
                  title='Évolution du gain des ventes',
                  hover_data={"Date": True,
                              "Semaine": True,
                              "Gain de vente": True,
                              "Média": False},
                  width=700, height=500)
    fig.update_traces(mode="markers+lines")
    fig.update_layout(xaxis_title='Semaine', yaxis_title="Gain de vente")
    fig.show(config=config)

    _afficher_totaux('\033[92m\033[4m\033[1mSomme des gains de vente des médias :\033[0m', gains_df, 'Gain de vente')


def choix_media(type_adstock, validate, nb_date, **dict_autres_var):
    '''
    Run the sales simulation and display its history, charts and totals.

    Called by `interactive_output` each time a simulator widget changes.

    Params :
        - type_adstock : '1. Comparaison des médias' (one global amount per date put on
          every media, each media simulated on its own, the 3 best media plotted) or
          '2. Tous les médias' (one amount per media and per date, all media together)
        - validate : the simulation only runs when equal to 'Valider'
        - nb_date : number of investment dates (number of parameter panels)
        - dict_autres_var : dict of
                                    - Dates (`date_select{i}`)
                                    - Global adstock (`adstock_global{i}`, '1. Comparaison des médias')
                                    - Per-media adstock (`adstock_<media>{i}`, '2. Tous les médias')
    Raises :
        - ValueError if `type_adstock` is not one of the two simulation types
    '''
    if validate != 'Valider':
        return
    if type_adstock not in ('1. Comparaison des médias', '2. Tous les médias'):
        raise ValueError(f'Type de simulation inconnu : {type_adstock}')
    comparaison = type_adstock == '1. Comparaison des médias'

    liste_media = [md.split('_')[-1] for md in mdsp_cols]  # e.g. 'mdsp_dm' -> 'dm'
    liste_date = [get_last_sunday(dict_autres_var[f'date_select{i}']) for i in range(nb_date)]
    _signaler_dates_desordonnees(liste_date)
    ecarts, nb_semaines = _semaines_historique(liste_date)

    # Amount per date and per media (liste_media order, i.e. the mdsp_cols order the models use)
    if comparaison:
        montants_par_date = [[dict_autres_var[f'adstock_global{i}']] * len(liste_media) for i in range(nb_date)]
    else:
        montants_par_date = [[dict_autres_var[f'adstock_{md}{i}'] for md in liste_media] for i in range(nb_date)]
    historique = _construire_historique(montants_par_date, liste_date, ecarts, sum(nb_semaines))

    # Per-media adstock parameters in both simulation types: apply_adstock needs a scalar
    # length/peak/decay, not the whole L/P/D containers.
    liste_adstocks = [apply_adstock(ad, L[md], P[md], D[md]) for md, ad in zip(liste_media, historique)]

    ventes_par_media = {md: [] for md in liste_media}  # type 1: sales with only that media active
    vente_y_plot = []  # type 2: sales with every media active
    vente_y_base_sale = []  # sales without any investment
    semaine_x, date_x = [], []
    investit = False
    date_tour_boucle = liste_date[0]

    print('\033[92m\033[4m\033[1mHistorique :\033[0m')
    for num, nb in enumerate(nb_semaines):
        for _ in range(nb):
            indice_semaine = len(semaine_x) + 1
            _afficher_entete_semaine(indice_semaine, date_tour_boucle, liste_date)

            ad = [adstock[indice_semaine-1] for adstock in liste_adstocks]  # adstock of every media this week
            print(f' - Adstock des 10 médias : {[round(val) for val in ad]}\n')
            investit = investit or sum(ad) > 0

            if comparaison:
                for i, md in enumerate(liste_media):
                    ventes_par_media[md].append(get_pred_sales_all(date=date_tour_boucle, liste_adstock=ad, indice=i))
            else:
                vente_y_plot.append(get_pred_sales_individuel(date=date_tour_boucle, liste_adstock=ad))
            vente_y_base_sale.append(get_pred_sales_individuel(date=date_tour_boucle, liste_adstock=[0] * len(liste_media)))

            semaine_x.append(indice_semaine)
            date_x.append(date_tour_boucle)
            date_tour_boucle = get_next_week(date_tour_boucle)

        # Stop after an intermediate date that is not strictly before the next one (dates out of order)
        if num not in (0, nb_date - 1) and liste_date[num] >= liste_date[num+1]:
            break

    # Series sized on the weeks actually simulated (the loop above may stop early)
    n_semaines = len(semaine_x)
    if not investit:
        series = []  # nothing invested: only the base sales are plotted
    elif comparaison:
        # Rank media by total sales; ties are broken on the media key (descending)
        classement = sorted(((sum(ventes), md) for md, ventes in ventes_par_media.items()), reverse=True)
        series = [(dict_trad_media[md], ventes_par_media[md]) for _, md in classement[:3]]
    else:
        series = [('Tous les médias', vente_y_plot)]

    noms, ventes = ['Vente sans montant'] * n_semaines, list(vente_y_base_sale)
    noms_gain, gains = [], []
    for nom, ventes_serie in series:
        noms += [nom] * n_semaines
        ventes += ventes_serie
        noms_gain += [nom] * n_semaines
        gains += list(np.array(ventes_serie) - np.array(vente_y_base_sale))
    if not series:
        # Without investment the gain chart shows the base sales series with a zero gain
        noms_gain = ['Vente sans montant'] * n_semaines
        gains = list(np.array(vente_y_base_sale) - np.array(vente_y_base_sale))

    nb_series = len(series)
    ventes_df = pd.DataFrame({'Vente': ventes, 'Média': noms,
                              'Semaine': semaine_x * (nb_series + 1), 'Date': date_x * (nb_series + 1)})
    gains_df = pd.DataFrame({'Gain de vente': gains, 'Média': noms_gain,
                             'Semaine': semaine_x * max(nb_series, 1), 'Date': date_x * max(nb_series, 1)})
    _afficher_resultats(ventes_df, gains_df)

In [68]:
def simulateur(nb_date):
    '''
    Build the simulator interface.

    Params :
        - nb_date : number of investment dates (= number of parameter panels)

    Returns :
        an AppLayout. On the left: the simulation type, one panel per date (date,
        global amount for '1. Comparaison des médias', one amount per media for
        '2. Tous les médias') and the Valider/Bloquer toggle. On the right: the
        output of `choix_media`, re-run whenever one of these widgets changes.

    Relies on the module-level `style` dict being defined before the call.
    '''
    # (key suffix, slider label) per media. The suffixes build the `adstock_<suffix>{i}`
    # keys that choix_media reads.
    # NOTE(review): judging by the suffixes (auddig/audtr, vidtr/viddig), the labels
    # 'Radio'/'Audio' and 'Vidéo'/'Télévision' may be swapped — confirm against dict_trad_media.
    medias = [('dm', 'Email'),
              ('inst', 'Encart publicitaire'),
              ('nsp', 'Journaux'),
              ('auddig', 'Radio'),
              ('audtr', 'Audio'),
              ('vidtr', 'Vidéo'),
              ('viddig', 'Télévision'),
              ('so', 'Réseaux sociaux'),
              ('on', 'En ligne'),
              ('sem', 'Moteur de recherche')]

    def nouveau_slider(description, **options):
        '''Amount slider from 0 to 10,000,000 that only fires on release.'''
        return IntSlider(value=0, min=0, max=10000000, description=description,
                         continuous_update=False, style=style, **options)

    ###### WIDGETS ######
    validate = ToggleButtons(options=['Valider', 'Bloquer'], button_style='success', tooltips=['Lancer la simulation', 'Bloquer la simulation (pour ajuster les paramètres)'])
    type_adstock = Dropdown(options=['1. Comparaison des médias','2. Tous les médias',], value='1. Comparaison des médias', description='Type de simulation', disabled=False, style=style, layout={'width': 'max-content'})

    # interactive_output passes widget values to choix_media, so nb_date is wrapped in a widget too
    dict_widgets = {'type_adstock': type_adstock,  # '1. Comparaison des médias' or '2. Tous les médias'
                    'validate': validate,  # run or block the simulation
                    'nb_date': IntSlider(value=nb_date)}

    # Green frame around each date panel (shared by all panels)
    box_layout = Layout(align_items='stretch',
                        border='2px solid green',
                        width='max-content',
                        justify_content='space-between')

    ###### PARAMETRES ######
    liste_vbox = []
    for i in range(nb_date):
        date_widget = widgets.DatePicker(description=f'Date n°{i+1}', disabled=False,
                                         value=datetime.now().date(), layout={'width': 'max-content'})
        global_widget = nouveau_slider('1. Comparaison des médias : ', layout={'width': '350px'})
        media_widgets = [nouveau_slider(label, orientation='vertical', layout={'width': 'max-content'})
                         for _, label in medias]

        dict_widgets[f'date_select{i}'] = date_widget
        dict_widgets[f'adstock_global{i}'] = global_widget
        for (suffixe, _), w in zip(medias, media_widgets):
            dict_widgets[f'adstock_{suffixe}{i}'] = w

        liste_vbox.append(VBox([date_widget,  # date of investment n°(i+1)
                                global_widget,  # amount for '1. Comparaison des médias'
                                Label('2. Budget par semaine :'),
                                HBox(media_widgets)], layout=box_layout))

    ###### SIMULATEUR ######
    out = interactive_output(choix_media, dict_widgets)
    params = VBox([type_adstock] + liste_vbox)

    return AppLayout(left_sidebar=VBox([params, validate]),
                     right_sidebar=out,
                     grid_gap='10px',
                     pane_widths=[5, 1, 2])
      
# Widget style shared by the simulator widgets: descriptions take their natural width
style = {'description_width': 'initial'}

# Choosing the number of dates rebuilds the whole simulator with that many parameter panels.
# (Dropdown has no `continuous_update` trait, so that argument is not passed.)
inter = interact(simulateur, nb_date=widgets.Dropdown(options=list(range(1, 11)),
                                                      value=1,
                                                      description='Nombre de jours',
                                                      style=style, layout={'width': 'max-content'}))

interactive(children=(Dropdown(description='Nombre de jours', layout=Layout(width='max-content'), options=(1, …