In [2]:
# %%
import pandas as pd
import numpy as np

import jax.numpy as jnp
from sklearn.metrics import mean_absolute_percentage_error

from lightweight_mmm import preprocessing

from lightweight_mmm import utils
from lightweight_mmm import lightweight_mmm
from lightweight_mmm import plot
from lightweight_mmm import optimize_media
from lightweight_mmm import utils


In [4]:
# %% load data
# Read the simulated dataset and discard every row that contains a missing value.
sim_data = pd.read_csv("../input/simulated_data.csv").dropna(axis=0)


In [5]:
# %% config: set variables
# Column groups of the input table: spend per channel, media activity per
# channel, and the target column (kept as a one-element list).
spend_columns = ['spend_TV', 'spend_Facebook', 'spend_Search']
media_columns = ['impressions_TV', 'impressions_Facebook', 'clicks_Search']
sales = ['total_revenue']

# Slice the loaded table into one DataFrame per column group.
spend_data, media_data, sales_data = (
    sim_data[group] for group in (spend_columns, media_columns, sales)
)


In [6]:
# %% convert jax format
# Spend and media keep their 2-D layout: one row per observation, one column
# per channel.
cost_jax = jnp.array(spend_data.to_numpy())
media_data_jax = jnp.array(media_data.to_numpy())

# `sales_data` is selected with a one-element column list, so its values have
# shape (n_rows, 1). Flatten to a 1-D vector of length n_rows: a column-shaped
# target does not broadcast against per-sample predictions of shape
# (n_samples, n_rows), which is the kind of shape mismatch reported by the
# plotting cell further down.
# NOTE(review): a model that was fitted earlier with the 2-D target has to be
# refitted with this 1-D target for the change to take effect on it.
sales_jax = jnp.ravel(jnp.array(sales_data.to_numpy()))


In [7]:
# %% train test split
# The final 20 rows are held out for testing; everything before them is used
# for training.
split_point = media_data_jax.shape[0] - 20

media_data_train, media_data_test = (
    media_data_jax[:split_point],
    media_data_jax[split_point:],
)
target_train, target_test = sales_jax[:split_point], sales_jax[split_point:]

# Costs are collapsed to a single total per channel for each period.
costs_train = cost_jax[:split_point].sum(axis=0)
costs_test = cost_jax[split_point:].sum(axis=0)

media_names = media_data.columns


In [8]:
# %% scaling
# One scaler per data group, each configured with divide_operation=jnp.mean.
# NOTE: this cell overwrites its own inputs with the scaled arrays, so it is
# meant to run once per kernel session; running it again would refit the
# scalers on data that is already scaled.
media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
costs_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)

# Scalers are fitted on the training slice only and then re-applied unchanged
# to the test slice.
media_data_train = media_scaler.fit_transform(media_data_train)
media_data_test = media_scaler.transform(media_data_test)

target_train = target_scaler.fit_transform(target_train)
target_test = target_scaler.transform(target_test)

costs_train = costs_scaler.fit_transform(costs_train)
# The scaled test costs used to be stored under the misspelled name
# `cost_test`, which left `costs_test` unscaled and inconsistent with the
# media and target test slices above.
costs_test = costs_scaler.transform(costs_test)
cost_test = costs_test  # alias kept for any code still using the old name


In [9]:
%%time
# Load the previously saved model object from disk instead of refitting it.
# NOTE(review): the .pkl extension suggests a pickle file — only load model
# files that come from a trusted source.
mmm = utils.load_model("../output/lightweight_mmm_20240723_model.pkl")

CPU times: total: 62.5 ms
Wall time: 588 ms
CPU times: total: 62.5 ms
Wall time: 588 ms


In [10]:
#%%
# Predict for the held-out media data. The target scaler is handed to the
# model as well — presumably so predictions are returned on the original
# revenue scale; confirm against the library documentation.
pred = mmm.predict(media=media_data_test, target_scaler=target_scaler)


In [11]:
# %% posterior summary
# Print per-parameter summary statistics of the loaded model (the table below
# lists mean, std, median, 5%/95% bounds, n_eff and r_hat).
mmm.print_summary()



                                         mean       std    median      5.0%     95.0%     n_eff     r_hat
                      coef_media[0]      0.31      0.28      0.23      0.00      0.69   3089.99      1.00
                      coef_media[1]      0.56      0.57      0.37      0.00      1.35   2169.94      1.00
                      coef_media[2]      0.40      0.38      0.29      0.00      0.92   3085.64      1.00
                      coef_trend[0]     -0.03      0.03     -0.02     -0.07      0.01    715.43      1.00
                         expo_trend      0.72      0.17      0.68      0.50      0.96    959.21      1.01
             gamma_seasonality[0,0]      0.24      0.74      0.23     -0.96      1.43    844.29      1.01
             gamma_seasonality[0,1]     -0.33      0.47     -0.31     -1.07      0.46   1183.54      1.01
             gamma_seasonality[1,0]      0.04      0.16      0.03     -0.24      0.30    973.14      1.01
             gamma_seasonality[1,1]      0.07

In [12]:
# %%
# Posterior media effect and ROI estimates per channel.
# The scalers fitted in the scaling cell are passed along so the metrics are
# computed against unscaled target and cost values, consistent with the
# predict call, which also receives `target_scaler`.
# NOTE(review): this assumes the loaded model was fitted on data scaled the
# same way as in this notebook — confirm against the training notebook.
media_effect_hat, roi_hat = mmm.get_posterior_metrics(
    target_scaler=target_scaler,
    cost_scaler=costs_scaler,
)


In [13]:
# %%
# Posterior draws of `lag_weight`, one column per media channel.
# np.asarray converts the traced array through its public interface instead
# of reaching into the private `_value` attribute.
lambda_table = pd.DataFrame(
    np.asarray(mmm.trace['lag_weight']),
    columns=mmm.media_names,
)
# Open question from the original analysis: the expectation was that
# mmm.trace['lag_weight'] would recover the lambdas used to simulate the data,
# but the posterior means come out (relatively) very large.
# NOTE(review): the three means are nearly identical across channels, which
# may indicate that the data barely informs this parameter — possibly related
# to the target-shape problem reported by the plotting cell; confirm after
# refitting with a 1-D target.
lambda_table.mean(axis=0)


TV          0.678541
Facebook    0.686391
Search      0.681990
dtype: float32

In [14]:
# %% model fit plot — this call currently raises a TypeError (see output below)
# NOTE(review): the shapes in the error, (84, 1) vs (50000, 84), look like a
# column-shaped target being subtracted from per-sample predictions. `sales`
# is a one-element column list, so the target built from it is 2-D (n, 1).
# Presumably the loaded model was fitted with such a target and needs to be
# refitted with a flattened 1-D target — confirm against the training run.
plot.plot_model_fit(mmm, target_scaler=target_scaler)


TypeError: sub got incompatible shapes for broadcasting: (84, 1), (50000, 84).