# Overview

The LightweightMMM package (built using Numpyro and JAX) helps advertisers easily build Bayesian MMM models by providing the functionality to appropriately scale data, evaluate models, optimise budget allocations and plot common graphs used in the field.

##### Simplified Model Overview

An MMM quantifies the relationship between media channel activity and sales, while controlling for other factors. A simplified model overview is shown below. An MMM is typically run using weekly level observations (e.g. the KPI could be sales per week), however, it can also be run at the daily level.

![image.png](attachment:image.png)


Where:

1."kpi" is typically the volume or value of sales per time period

2."alpha" is the model intercept

3."trend" is a flexible non-linear function that captures trends in the data

4."seasonality" is a sinusoidal function with configurable parameters that flexibly captures seasonal trends

5."media" is a matrix of different media channel activity (typically impressions or costs per time period) which receives transformations depending on the model used

6."other factore" is a matrix of other factors that could influence sales.

In [1]:
"""

Install Pip dependencies

"""
# First would be to install lightweight_mmm
!pip install --upgrade lightweight_mmm --user

# -- Load older versions of the following libraries
!pip install --upgrade numpyro==0.13.2 --user
!pip install --upgrade jax==0.4.23 --user
!pip install --upgrade jaxlib==0.4.23 --user

!pip install --upgrade memory_profiler --user
!pip install --upgrade worker --user
!pip install --upgrade scipy==1.12.0 --user

!pip install --upgrade prophet --user

Collecting matplotlib==3.6.1 (from lightweight_mmm)
  Using cached matplotlib-3.6.1-cp312-cp312-win_amd64.whl
Collecting seaborn==0.11.1 (from lightweight_mmm)
  Using cached seaborn-0.11.1-py3-none-any.whl.metadata (2.3 kB)
Collecting protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 (from tensorflow-intel==2.16.1->tensorflow>=2.7.2->lightweight_mmm)
  Using cached protobuf-4.25.3-cp310-abi3-win_amd64.whl.metadata (541 bytes)
Using cached seaborn-0.11.1-py3-none-any.whl (285 kB)
Using cached protobuf-4.25.3-cp310-abi3-win_amd64.whl (413 kB)
Installing collected packages: protobuf, matplotlib, seaborn
Successfully installed matplotlib-3.6.1 protobuf-4.25.3 seaborn-0.11.1
Collecting numpyro==0.13.2
  Using cached numpyro-0.13.2-py3-none-any.whl.metadata (36 kB)
Using cached numpyro-0.13.2-py3-none-any.whl (312 kB)
Installing collected packages: numpyro
Successfully installed numpyro-0.13.2
Collecting jax==0.4.23
  Using cached jax-0.4.23-py3-none-any.whl.m

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
shap 0.45.0 requires slicer==0.0.7, but you have slicer 0.0.8 which is incompatible.




# 1.Import Libraries

In [2]:
"""
Import Lightweight MMM Libraries

"""

# Import jax.numpy and any other library we might need.
import jax.numpy as jnp
import numpyro

# Import the relevant modules of the library
from lightweight_mmm import lightweight_mmm
from lightweight_mmm import optimize_media
#from lightweight_mmm import plot
from lightweight_mmm import preprocessing
from lightweight_mmm import utils


# importing the library
from memory_profiler import profile
     

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
"""
Import required libraries

"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.model_selection import TimeSeriesSplit
from sklearn import metrics
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import SGDClassifier
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import GridSearchCV,RandomizedSearchCV,cross_val_score
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn import linear_model
from sklearn.preprocessing import LabelEncoder,OrdinalEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import make_column_transformer
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score,precision_score,recall_score
from sklearn.metrics import mean_absolute_percentage_error,r2_score,mean_squared_error
from sklearn.metrics import RocCurveDisplay,roc_curve,auc
from xgboost import XGBClassifier,XGBRegressor
from sklearn.svm import SVR
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn.linear_model import Ridge
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import RepeatedKFold
import time
from sklearn.impute import KNNImputer
from sklearn.naive_bayes import GaussianNB

from sklearn.datasets import make_blobs
from sklearn.datasets import make_moons
from sklearn.datasets import make_classification
import re

import shap
from pprint import pprint

# For ordinal encoding categorical variables, splitting data
from tqdm import tqdm

#Import 'scope' from hyperopt in order to obtain int values for certain hyperparameters.
from hyperopt.pyll.base import scope
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
import warnings

import arviz

from IPython.display import display, HTML
display(HTML("<st yle>.container { width:80% !important; }</style>"))

pd.set_option('display.max_rows', 300)
pd.set_option('display.max_columns', 300)

import warnings
warnings.filterwarnings('ignore')

ModuleNotFoundError: No module named 'matplotlib.tri._triangulation'

# 2.Define Functions

In [None]:
class data_pre_processing():
    """
    Pre process pandas dataframes
    """
    
    def df_description(df):
        print('shape : ',df.shape)
        print('dtypes :','\n',df.dtypes)
        return None

    def convert_col_names_to_lower_case(df):
        for col in df.columns:
            df = df.rename(columns = {col:col.lower()})
        return df
    
    def describe_data(df):
        display(df.describe(include='all'))
        return None
    
    def data_info(df):
        display(df.info(verbose=True, show_counts=True))
        return None

    def nulls_in_data(df):
        print(df.isnull().sum())
        return None
    
    def duplicate_rows_at_primary_key_level(df,level_of_the_data :list):
        
        df = df[df.duplicated(subset=level_of_the_data, keep=False)]
        if df.shape[0] >0:
            print(" Dataframe has duplicates, total rows with duplicates :",df.shape[0])
        else:
            print(" No duplicates in data")      
        return df
    
    def percentage_nulls_in_each_col(df):
        """
        Return a DF of total null percentage in each column
        Args:
            df (Dataframe): Input DF
        """
        NA = pd.DataFrame(data=[df.isna().sum().tolist(), ["{:.2f}".format(i)+'%' \
           for i in (df.isna().sum()/df.shape[0]*100).tolist()],df.dtypes.tolist()], 
           columns=df.columns, index=['NA Count', 'NA Percent','Dtypes']).transpose()
        display(NA)
        
        return None

        
            
    def drop_cols(df,column_list_to_drop):
        """

        Args:
            df (Dataframe): Input Dataframe
            column_list_to_drop (List of Columns): List of columns to be dropped from DF
            
        Result :
            df with updated columns post dropping.
        """
        
        return df.drop(columns = column_list_to_drop)
    
    def df_columns_and_dtypes_into_list(df) -> dict:
        """

        Returns Dataframe columns and respcetive dtypes as a dictionary.
        Args:
            df (Dataframe): Input Dataframe

        Returns:
            dict: {col1:dtype,col2:dtype2}
        """
        
        return {df.columns.tolist()[i] : str(df.dtypes.tolist()[i]) for i in range(len(df.columns)) }
    
    
    def convert_dtypes_from_dict(df, col_to_dtype_mapping_dict : dict ):
        """
        df : Input Dataframe
        col_to_dtype_mapping_dict: {column:dtype}
        
        return : df with updated dtypes
        
        """
        print('Converting Column Datatypes')
        for col,dtype in col_to_dtype_mapping_dict.items():
            try:
                if col in df.columns:
                    if dtype == 'string':
                        df[col] = df[col].astype(str)
                    elif dtype == 'int':
                        df[col] = df[col].astype(int)
                    elif dtype == 'float':
                        df[col] = df[col].astype(float)
                    elif dtype == 'datetime':
                        df[col] = pd.to_datetime(df[col])
            except Exception as e:
                print('Error processing : ',col, ' ; ',e)
        return df
    


    def ydata_profiling(df):
        """
        Run ydata_profile library for data
        Args:
            df (Dataframe): Df to be profiled

        Returns:
            _type_: _description_
        """
        # !pip install ydata-profiling --ignore-installed llvmlite --user
        # from ydata_profiling import ProfileReport
        # or
        # !pip install ydata-profiling==4.1.2
        # !pip install pydantic==2.6.0 --user
        from ydata_profiling import ProfileReport
        profile = ProfileReport(df)
        
        return profile
    


    def process_nulls_in_df(df,null_columns)-> pd.core.frame.DataFrame:
        """_summary_

        Args:
            df (_type_): _description_
            null_columns (_type_): _description_

        Returns:
            pd.core.frame.DataFrame: _description_
        """
    

        
        return None
    
def plot_line_charts_with_lags(df, x_variable, y_variable='sales'):
    """
    Creates a 1x3 matrix of line charts with the specified x variable and sales over time,
    including charts with 1-week and 2-week lags for the x variable, all scaled based on Z-scores.

    Args:
    - df (DataFrame): The dataset containing the variables.
    - x_variable (str): The name of the x variable to plot against time.
    - y_variable (str): The name of the y variable, default is 'sales'.
    """

    # Ensuring the date column is in datetime format
    df['week_start_date'] = pd.to_datetime(df['week_start_date'])

    # Creating lagged variables
    df[f'{x_variable}_lag1'] = df[x_variable].shift(-1)
    df[f'{x_variable}_lag2'] = df[x_variable].shift(-2)

    # Scaling the variables based on Z-scores
    scaler = StandardScaler()
    scaled_vars = scaler.fit_transform(df[[y_variable, x_variable, f'{x_variable}_lag1', f'{x_variable}_lag2']].dropna())
    df_scaled = pd.DataFrame(scaled_vars, columns=[y_variable, x_variable, f'{x_variable}_lag1', f'{x_variable}_lag2'])

    # Setting up the figure for plotting
    fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True)

    # Because scaling and shifting may lead to a different number of rows, we need to align the date index.
    dates = df['week_start_date'][-len(df_scaled):]

    # Original variable line chart
    axes[0].plot(dates, df_scaled[y_variable], label=y_variable, color='blue')
    axes[0].plot(dates, df_scaled[x_variable], label=x_variable, color='red')
    axes[0].set_title(f'Original {y_variable} and {x_variable} (Scaled)')
    axes[0].legend(loc='upper left')

    # 1-week lag line chart
    axes[1].plot(dates, df_scaled[y_variable], label=y_variable, color='blue')
    axes[1].plot(dates, df_scaled[f'{x_variable}_lag1'], label=f'{x_variable} (1-week lag)', color='red')
    axes[1].set_title(f'{y_variable} and 1-week Lag of {x_variable} (Scaled)')
    axes[1].legend(loc='upper left')

    # 2-week lag line chart
    axes[2].plot(dates, df_scaled[y_variable], label=y_variable, color='blue')
    axes[2].plot(dates, df_scaled[f'{x_variable}_lag2'], label=f'{x_variable} (2-week lag)', color='red')
    axes[2].set_title(f'{y_variable} and 2-week Lag of {x_variable} (Scaled)')
    axes[2].legend(loc='upper left')

    
    for ax in axes:
        ax.set_xlabel('Date')
        ax.set_ylabel('Scaled Value')
        ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.show()

# 3.Pre-process Data

In [None]:
"""
Read the dataset.

Four years' (209 weeks) records of sales, media impression and media spending at weekly level.

1. Media Variables

Media Impression (prefix='mdip_'): impressions of 13 media channels: direct mail, insert, newspaper, digital audio, radio, TV, digital video, social media, online display, email, SMS, affiliates, SEM.
Media Spending (prefix='mdsp_'): spending of media channels.
2. Control Variables

Macro Economy (prefix='me_'): CPI, gas price.
Markdown (prefix='mrkdn_'): markdown/discount.
Store Count ('st_ct')
Retail Holidays (prefix='hldy_'): one-hot encoded.
Seasonality (prefix='seas_'): month, with Nov and Dec further broken into to weeks. One-hot encoded.
3. Sales Variable ('sales')

"""
    



mmm_df = pd.read_csv('MMM_Data.csv')
print(mmm_df.shape)



In [None]:
"""
Define new names for columns.
Re-map column names.

"""

mmm_df2 = mmm_df.copy()

mapped_columns_dict = {'mdip_':'media_impression_',\
                        'mdsp_':'media_spend_',\
                        'me_': 'macro_econ_',\
                        'mrkdn_':'markdown_discount_',\
                        'st_ct':'store_count',\
                        'wk_strt_dt':'week_start_date',\
                        'wk_in_yr_nbr':'week_in_year',\
                        'yr_nbr':'year',\
                        'qtr_nbr':'quarter',\
                        'prd':'period',\
                        'wk_nbr':'week_in_month',\
                        '_dm':'_direct_mail',\
                        '_inst':'_insert',\
                        '_nsp':'_newspaper',\
                        '_auddig':'_dig_audio',\
                        '_audtr':'_radio',\
                        '_vidtr':'_tv',\
                        '_viddig':'_dig_video',\
                        '_so':'_social_media',\
                        '_on':'_online display',\
                        '_em':'_email',\
                        '_sms':'_sms',\
                        }

datatype_mapping = {'week_start_date':'datetime'}

# Map column names
for short_form,long_form in mapped_columns_dict.items():
    for col in mmm_df2.columns:
        if short_form in col:
            new_col_name = re.sub(short_form,long_form,col)
            # print(short_form,"--> ",col,"--> ",new_col_name)
            mmm_df2 = mmm_df2.rename(columns={col:new_col_name})


# Lowercase column names
mmm_df2 = data_pre_processing.convert_col_names_to_lower_case(mmm_df2)

# String to Date conversion
mmm_df2 = data_pre_processing.convert_dtypes_from_dict(df = mmm_df2,col_to_dtype_mapping_dict=datatype_mapping)





In [None]:
"""
Define continuous,categorical and target column

"""


continuous_column = [ 'media_impression_direct_mail', 'media_impression_insert',
       'media_impression_newspaper', 'media_impression_dig_audio',
       'media_impression_radio', 'media_impression_tv',
       'media_impression_dig_video', 'media_impression_social_media',
       'media_impression_online display', 'media_impression_email',
       'media_impression_sms', 'media_impression_aff', 'media_impression_sem',
       'media_spend_direct_mail', 'media_spend_insert',
       'media_spend_newspaper', 'media_spend_dig_audio', 'media_spend_radio',
       'media_spend_tv', 'media_spend_dig_video', 'media_spend_social_media',
       'media_spend_online display', 'media_spend_sem',
       'macro_econ_ics_all', 'macro_econ_gas_dpg', 'store_count',
       'markdown_discount_valadd_edw', 'markdown_discount_pdm']

one_hot_encoded_columns = ["hldy_black friday",
       "hldy_christmas day", "hldy_christmas eve", "hldy_columbus day",
       "hldy_cyber monday", "hldy_day after christmas", "hldy_easter",
       "hldy_father's day", "hldy_green monday", "hldy_july 4th",
       "hldy_labor day", "hldy_mlk", "hldy_memorial day", "hldy_mother's day",
       "hldy_nye", "hldy_new year's day", "hldy_pre thanksgiving",
       "hldy_presidents day", "hldy_prime day", "hldy_thanksgiving",
       "hldy_valentine's day", "hldy_veterans day",
       #  "seas_period_1",
       # "seas_period_2", "seas_period_3", "seas_period_4", "seas_period_5",
       # "seas_period_6", "seas_period_7", "seas_period_8", "seas_period_9",
       # "seas_period_12", "seas_week_40", "seas_week_41", "seas_week_42",
       # "seas_week_43", "seas_week_44", "seas_week_45", "seas_week_46",
       # "seas_week_47", "seas_week_48"
       ]

date_columns = ['week_start_date', 'year', 'quarter', 'period', 'week_in_month','week_in_year']

target_column = ["sales"]

""""

Lightweight mmm columns

"""

media_spend_cols = [ 'media_spend_direct_mail', 'media_spend_insert',
       'media_spend_newspaper', 'media_spend_dig_audio', 'media_spend_radio',
       'media_spend_tv', 'media_spend_dig_video', 'media_spend_social_media',
       'media_spend_online display', 'media_spend_sem']


media_impression_cols = [ 'media_impression_direct_mail', 'media_impression_insert',
       'media_impression_newspaper', 'media_impression_dig_audio',
       'media_impression_radio', 'media_impression_tv',
       'media_impression_dig_video', 'media_impression_social_media',
       'media_impression_online display', 'media_impression_sem']


extra_features_cols = ['macro_econ_ics_all', 'macro_econ_gas_dpg', 'store_count',
       'markdown_discount_valadd_edw', 'markdown_discount_pdm']+one_hot_encoded_columns

# 4.Train-Test Split

In [None]:
"""
Train test split - Use time-series split as marketing spend data is Longitudnal in nature. 

"""


"""

Time-series Split -->

"""
tss = TimeSeriesSplit(n_splits = 4)

mmm_df2_time_series = mmm_df2.copy()
mmm_df2_time_series.set_index('week_start_date', inplace=True)
mmm_df2_time_series.sort_index(inplace=True)

independant_columns = [i for i in continuous_column+one_hot_encoded_columns+date_columns if 'week_start_date' not in i]
X = mmm_df2_time_series[independant_columns]
y = mmm_df2_time_series[target_column]

# Add 0.01 to each numerical value for lightweight mmm
for col in X.columns:
    if col in continuous_column:
        X[col] = X[col] + 0.1

for train_index, test_index in tss.split(X):
    X_train, X_test = X.iloc[train_index, :], X.iloc[test_index,:]
    y_train, y_test = y.iloc[train_index], y.iloc[test_index]


print(X_train.shape,X_test.shape)




"""

Regular Split -->

"""

# independant_columns = [i for i in continuous_column+one_hot_encoded_columns+date_columns if 'week_start_date' not in i ]
# X = mmm_df2[independant_columns]
# y = mmm_df2[target_column]

# # Using the train test split function
# X_train, X_test, y_train, y_test = train_test_split(
#   X,y , random_state=104,test_size=0.25, shuffle=True)

# print(X_train.shape,X_test.shape)

#### 4.2 Splitting data into Media,Control and Sales variables. Applying scaler functions.

In [None]:
"""
Lightweight mmm implementation:

1) Split the data into Media Impression, Media Spend and Control Variables(Extra). This is a requirement for the lightweight mmm model.

"""


media_impression_train =  X_train[media_impression_cols].to_numpy()

media_impression_test = X_test[media_impression_cols].to_numpy()

target_train = y_train[target_column].sum(axis = 1).to_numpy()

target_test = y_test[target_column].sum(axis = 1).to_numpy()

media_spend_train =  X_train[media_spend_cols].to_numpy()

media_spend_test = X_test[media_spend_cols].to_numpy()

extra_train =  X_train[extra_features_cols].to_numpy()

extra_test = X_test[extra_features_cols].to_numpy()



# Cost is used to assume prior in later steps
costs = X_train[media_spend_cols].sum().to_numpy()

print(costs)



"""

Scaling: Use in-built lightweight mmm scaler to scale the columns. Alternate scalers can be used as well.

"""



media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
extra_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
costs_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean, multiply_by=0.15)


#CHANGE THIS-
media_spend_train_scaled = media_scaler.fit_transform(media_spend_train)


extra_train_scaled = extra_scaler.fit_transform(extra_train)
target_train_scaled = target_scaler.fit_transform(target_train)
media_spend_scaled = costs_scaler.fit_transform(costs)


# Test data
extra_test_scaled = extra_scaler.transform(extra_test)
media_spend_test_scaled = media_scaler.fit_transform(media_spend_test)

target_test_scaled =  target_scaler.fit_transform(target_test)


for i in [media_spend_train_scaled,extra_train_scaled,target_train_scaled,media_spend_scaled]:
    print(type(i),i.shape)

print(media_spend_train_scaled.shape[1],  len(media_spend_scaled))

#### Media Saturation and Lagging

It is likely that the effect of a media channel on sales could have a lagged effect which tapers off slowly over time. There are three different approaches to capture this. We compare all three approaches and use the approach that works the best. The approach that works the best will typically be the one which has the best out-of-sample fit (which is one of the generated outputs). The functional forms of these three approaches are briefly described below.

Adstock: Applies an infinite lag that decreases its weight as time passes.

![image-2.png](attachment:image-2.png)

Hill-Adstock: Applies a sigmoid like function for diminishing returns to the output of the adstock function.

![image.png](attachment:image.png)


Carryover: Applies a causal convolution giving more weight to the near values than distant ones.

![image-3.png](attachment:image-3.png)



In [None]:
"""
Test for best ad-stock model

Error1: ValueError: The number of data channels provided must match the number of cost values.

"""

%time



adstock_models = ["adstock", "hill_adstock", "carryover"]
degrees_season = [1,2,3]



model_output_dict = {}

# try:
for model_name in adstock_models:
    for degrees in degrees_season:
        mmm = lightweight_mmm.LightweightMMM(model_name=model_name)
        mmm.fit(media=media_spend_train_scaled,
                media_prior=media_spend_scaled,
                target=target_train_scaled,
                extra_features=extra_train_scaled,
                number_warmup=100,
                number_samples=130,
                number_chains=10,
                degrees_seasonality=degrees,
                weekday_seasonality=False,
                seasonality_frequency=52,
                seed=1)
        
        mmm.print_summary()
        prediction = mmm.predict(
        media=media_spend_test_scaled,
        extra_features=extra_test_scaled,
        target_scaler=target_scaler)
        print("prediction.shape",prediction.shape)
        p = prediction.mean(axis=0)

        mape = mean_absolute_percentage_error(target_test, p)
        r2 = arviz.r2_score(target_test, p)
        # r2 = r2_score(target_test, p)
        print(f"model_name={model_name} degrees={degrees} MAPE={mape} R2={r2}")

        model_output_dict[model_name+"_"+str(degrees)] = (mape,r2)

            
# except Exception as e:
#     print(e)
print(model_output_dict)   


In [None]:
"""
Analyse all models: All their accuracies (Mape) and R2 values are similar.

Carryover model with degree 3 has the highest accuracy in terms of both MAPE and R2. We will go ahead with this.
"""

print(model_output_dict)

mmm = lightweight_mmm.LightweightMMM(model_name="carryover")
mmm.fit(media=media_spend_train_scaled,
        media_prior=media_spend_scaled,
        target=target_train_scaled,
        extra_features=extra_train_scaled,
        number_warmup=100,
        number_samples=130,
        number_chains=10,
        degrees_seasonality=3,
        weekday_seasonality=False,
        seasonality_frequency=52,
        seed=1)

mmm.print_summary()
prediction = mmm.predict(
media=media_spend_test_scaled,
extra_features=extra_test_scaled,
target_scaler=target_scaler)
print("prediction.shape",prediction.shape)
p = prediction.mean(axis=0)

mape = mean_absolute_percentage_error(target_test, p)
r2 = arviz.r2_score(target_test, p)
# r2 = r2_score(target_test, p)
print(f"model_name={model_name} degrees={degrees} MAPE={mape} R2={r2}")


In [None]:
{'channel_'+str(i):media_spend_cols[i] for i in range(len(media_spend_cols))}

### Select best model and run diagnostics

In [None]:
plot.plot_model_fit(mmm, target_scaler=target_scaler)

In [None]:
new_predictions = mmm.predict(media=media_spend_test_scaled,
                              extra_features=extra_test_scaled,
                            #   target_scaler=target_scaler,
                              seed=1)

plot.plot_out_of_sample_model_fit(out_of_sample_predictions=new_predictions,
                                 out_of_sample_target=target_scaler.transform(target_test))

In [None]:
def apply_adstock(x, L, P, D):
    '''
    params:
    x: original media variable, array
    L: length
    P: peak, delay in effect
    D: decay, retain rate
    returns:
    array, adstocked media variable
    '''
    x = np.append(np.zeros(L-1), x)
    
    weights = np.zeros(L)
    for l in range(L):
        weight = D**((l-P)**2)
        weights[L-1-l] = weight
    
    adstocked_x = []
    for i in range(L-1, len(x)):
        x_array = x[i-L+1:i+1]
        xi = sum(x_array * weights)/sum(weights)
        adstocked_x.append(xi)
    adstocked_x = np.array(adstocked_x)
    return adstocked_x



In [None]:
# L = 12
# D = 0.54
# P = 1.01 

# x = x_plot
# carryover_tv = apply_adstock(x, L, P, D)


# plt.plot(x_plot)
# plt.plot(carryover_tv)
# plt.show()

### Plotting Prior and Posterior Distributions

In [None]:
plot.plot_prior_and_posterior(media_mix_model=mmm)

In [None]:
"""
Bug- to be fixed
"""

# plot.plot_media_channel_posteriors(media_mix_model=mmm,channel_names=media_spend_cols)

### Baseline-contribution chart across the time periods

In [None]:
media_contribution, roi_hat = mmm.get_posterior_metrics(target_scaler=target_scaler, cost_scaler=costs_scaler)
plot.plot_media_baseline_contribution_area_plot(media_mix_model=mmm,
                                                target_scaler=target_scaler,
                                                fig_size=(30,10),
                                                channel_names = media_spend_cols
                                                )

### Media contribution and ROI contribution per channel

In [None]:
plt.figure(figsize=(20,20))
plot.plot_bars_media_metrics(metric=media_contribution, metric_name="Media Contribution", channel_names=media_spend_cols,interval_mid_range= 0.9)

In [None]:
plt.figure(figsize=(20,20))
plot.plot_bars_media_metrics(metric=roi_hat, metric_name="ROI hat", channel_names=media_spend_cols,interval_mid_range= 0.9)

### Response Curves

In [None]:
plot.plot_response_curves(media_mix_model=mmm, target_scaler=target_scaler,seed = 1)

# Budget Optimization - Find the best media allocation based on MMM model, prices and a budget.

In [None]:
""" 
Define the number of periods for which we want to run optimization

"""

n_media_channels = 10
prices = jnp.ones(mmm.n_media_channels)
n_time_periods = 5
media_data = X[media_spend_cols].to_numpy()

budget = jnp.sum(jnp.dot(prices, media_data.mean(axis=0)))* n_time_periods

In [None]:
# Run optimization with the parameters of choice.
solution, kpi_without_optim, previous_media_allocation = optimize_media.find_optimal_budgets(
    n_time_periods=n_time_periods,
    media_mix_model=mmm,
    extra_features= extra_scaler.transform(X[extra_features_cols].to_numpy())[:n_time_periods],
    budget=budget,
    prices=prices,
    media_scaler=media_scaler,
    target_scaler=target_scaler,
    seed=1)

In [None]:
# Obtain the optimal weekly allocation.
optimal_buget_allocation = prices * solution.x
optimal_buget_allocation

In [None]:

# similar renormalization to get previous budget allocation
previous_budget_allocation = prices * previous_media_allocation
previous_budget_allocation

In [None]:
# Both these values should be very close in order to compare KPI
budget, optimal_buget_allocation.sum()

In [None]:
# Both numbers should be almost equal
budget, jnp.sum(solution.x * prices)

In [None]:
{'channel_'+str(i):media_spend_cols[i] for i in range(len(media_spend_cols))}

In [None]:
# Plot out pre post optimization budget allocation and predicted target variable comparison.
plot.plot_pre_post_budget_allocation_comparison(media_mix_model=mmm, 
                                                kpi_with_optim=solution['fun'], 
                                                kpi_without_optim=kpi_without_optim,
                                                optimal_buget_allocation=optimal_buget_allocation, 
                                                previous_budget_allocation=previous_budget_allocation, 
                                                figure_size=(10,10))

## Save Model

In [None]:
# file_path = "mmm_carryover_2_12_apr.pkl"
# utils.save_model(media_mix_model=mmm, file_path=file_path)

In [None]:
# file_path = "mmm_adstock_2_21_apr.pkl"
# utils.save_model(media_mix_model=mmm, file_path=file_path)