# Data input structure

## Each row represents
- One **Brand**
- In one **Country**
- At one **Month** (Year + Month)

---

## Inputs to the model (`X`)
Each row must provide:

1. **Time features**
   - `time_idx` (int): monotonic month index (0, 1, 2, …).
   - `month_sin` (float): `sin(2π * month/12)`.
   - `month_cos` (float): `cos(2π * month/12)`.

2. **Categorical features**
   - `country_id` (int): encoded Country.
   - `brand_id` (int): encoded Brand.

3. **Controllable features (vector)**
   - Shape: `(n_controls,)`.
   - Example (if n_controls=10):
     `[promotions_index, media_invest, discount_pct, inflation, temperature_c, holiday_flag, Meaning, Difference, Salience, Premium]`
   - Values must be standardized (mean=0, std=1).
   - At **forecast time**: feed a vector of zeros.

---

## Extra data (not part of model inputs but needed in training)
- `group_id` (int): unique ID for each (Country × Date).  
  - All brands in the same Country+Date share the same `group_id`.  
  - Used in loss to apply softmax within groups.

- `y_true` (float): actual Power percentage for that row (0–100).  
  - During training: available.  
  - During forecast: unknown (to be predicted).

---

## Tensor shapes during training
- `time_idx`: `(batch, 1)`
- `month_sin`: `(batch, 1)`
- `month_cos`: `(batch, 1)`
- `country_id`: `(batch, 1)`
- `brand_id`: `(batch, 1)`
- `controls`: `(batch, n_controls)`
- `y_true`: `(batch,)`
- `group_id`: `(batch,)`

---

## Model outputs
- `logits`: `(batch, 1)` — unnormalized additive scores per row.
- `ctrl_out`: `(batch, 1)` — controllable contribution (for regularization).

---

## Loss function
- Uses `(logits, group_id)` to compute **softmax per group** (so all brands in same Country×Date sum to 1).
- Compares predicted fractions × 100 to `y_true` (MSE).
- Adds small L2 penalty on `ctrl_out`.


In [1]:
import pandas as pd
import numpy as np
import pickle as pkl

# BG Data

In [32]:
# import files for bg data
input_path = "C:/Users/40107904/OneDrive - Anheuser-Busch InBev/ABI/WORK/hackathon_power/hackathon_lt_equity/dummy_data"
bg_data = pd.read_csv(f"{input_path}/bg.csv")

bg_data = bg_data[bg_data['PERIOD_TYPE'] == 'R12M']
bg_data["year"] = bg_data["LABEL_PERIOD"].str.slice(0, 4).astype(int)
bg_data["month"] = bg_data["LABEL_PERIOD"].str.slice(4, 8).str.upper().str.strip()

month_map = {
    'JAN': 1,
    'FEB': 2,
    'MAR': 3,
    'APR': 4,
    'MAY': 5,
    'JUN': 6,
    'JUL': 7,
    'AUG': 8,
    'SEP': 9,
    'OCT': 10,
    'NOV': 11,
    'DEC': 12,
    'YTD': 12
}

bg_data["month"] = bg_data["month"].apply(lambda x: month_map[x] if x in month_map else 12)
bg_data = bg_data.drop(columns=["LABEL_PERIOD", "PERIOD_TYPE", "COHORT", "COHORT_NAME"])
bg_data = bg_data.rename(columns={"COUNTRY_CODE": "country", "BRAND_DESC": "brand"})
bg_data.columns = bg_data.columns.str.lower()
bg_data = bg_data.sort_values(by=["country", "brand", "year", "month"])
bg_data = bg_data.reset_index(drop=True)
bg_numeric_cols = []
for col in bg_data.select_dtypes(include=['float64', 'int64']).columns:
    if col not in ['year', 'month']:
        bg_numeric_cols.append(col)
bg_data = bg_data[[col for col in bg_data.columns if col not in bg_numeric_cols] + ['power']]
bg_data['country'] = 'brazil'
bg_data['brand'] = bg_data['brand'].str.lower().str.strip()
bg_data

Unnamed: 0,country,brand,year,month,power
0,brazil,amstel,2021,3,3.1
1,brazil,amstel,2021,4,3.1
2,brazil,amstel,2021,5,3.1
3,brazil,amstel,2021,6,3.1
4,brazil,amstel,2021,7,3.2
...,...,...,...,...,...
1758,brazil,tiger,2024,3,0.4
1759,brazil,tiger,2024,4,0.4
1760,brazil,tiger,2024,5,0.4
1761,brazil,tiger,2024,6,0.4


In [33]:
bg_data = bg_data[bg_data['brand'].isin(['becks', 'brahma', 'budweiser', 'corona', 'skol', 'spaten', 'stella artois'])].reset_index(drop=True)

# for each year-month, the sum of power for all brands should be 100
bg_data['power'] = bg_data['power'] / bg_data.groupby(['year', 'month'])['power'].transform('sum') * 100
# round the power to 1 decimal places
bg_data['power'] = bg_data['power'].round(1)

bg_data

Unnamed: 0,country,brand,year,month,power
0,brazil,brahma,2021,3,52.0
1,brazil,brahma,2021,4,52.2
2,brazil,brahma,2021,5,52.2
3,brazil,brahma,2021,6,52.4
4,brazil,brahma,2021,7,52.2
...,...,...,...,...,...
200,brazil,stella artois,2024,3,11.6
201,brazil,stella artois,2024,4,11.5
202,brazil,stella artois,2024,5,11.5
203,brazil,stella artois,2024,6,11.5


In [34]:
bg_data_backup = bg_data.copy()

# Weather data

In [4]:
# import files for weather data
weather_data = pd.read_csv(f"{input_path}/mroi/brazil_weather_data.csv")

weather_data['timedesc'] = pd.to_datetime(weather_data['timedesc'], format='%Y-%m-%d')
weather_data['year'] = weather_data['timedesc'].dt.year
weather_data['month'] = weather_data['timedesc'].dt.month
weather_data = weather_data.drop(columns=['timedesc'])
weather_data = weather_data[[x for x in weather_data.columns if x != 'state']].groupby(['country','year', 'month']).mean().reset_index()
weather_data = weather_data.sort_values(by=['country', 'year', 'month']).reset_index(drop=True)
weather_data

Unnamed: 0,country,year,month,avgtemp,maxtemp,mintemp,prcp
0,brazil,2007,1,79.966611,94.850000,67.770536,0.226270
1,brazil,2007,2,78.827847,93.760714,66.831250,0.288651
2,brazil,2007,3,79.015594,93.469286,66.038571,0.192972
3,brazil,2007,4,77.859795,92.122321,64.649107,0.229469
4,brazil,2007,5,75.079357,89.963393,58.626786,0.139833
...,...,...,...,...,...,...,...
214,brazil,2024,11,80.000545,96.285714,65.895000,0.162963
215,brazil,2024,12,78.036464,92.826786,65.503571,0.269948
216,brazil,2025,1,79.103295,93.280357,67.039286,0.309204
217,brazil,2025,2,79.223506,93.354464,67.253571,0.277279


# Macroeco data

In [5]:
# import files for macroeconomic data
macro_data = pd.read_csv(f"{input_path}/mroi/brazil_macroeconomics_data.csv")

macro_data['timedesc'] = pd.to_datetime(macro_data['timedesc'], format='%Y-%m-%d')
macro_data['year'] = macro_data['timedesc'].dt.year
macro_data['month'] = macro_data['timedesc'].dt.month
macro_data = macro_data.drop(columns=['timedesc'])
macro_data = macro_data.groupby(['country','year', 'month']).mean().reset_index()

macro_data

Unnamed: 0,country,year,month,inflation_rate,unemployment_rate
0,brazil,2015,1,7.137737,7.354279
1,brazil,2015,2,7.701587,7.354279
2,brazil,2015,3,8.128505,7.354279
3,brazil,2015,4,8.171487,7.953818
4,brazil,2015,5,8.472943,7.953818
...,...,...,...,...,...
118,brazil,2024,11,3.715704,7.804767
119,brazil,2024,12,3.566220,7.804767
120,brazil,2025,1,5.228843,6.383003
121,brazil,2025,2,5.310633,6.383003


# Media data

In [6]:
# import files for media data
media_data = pd.read_csv(f"{input_path}/mroi/brazil_weekly_decomps_data.csv")

media_data_backup = media_data.copy()

In [57]:
media_data = media_data_backup.copy()
cols_to_pick = ['country', 'brand', 'vehicle', 'date', 'spend', 'maco']
media_data = media_data[cols_to_pick]


media_data['date'] = pd.to_datetime(media_data['date'], format='%Y-%m-%d')
media_data['year'] = media_data['date'].dt.year
media_data['month'] = media_data['date'].dt.month
media_data = media_data.drop(columns=['date'])

non_digital = ["radio", "paytv", "ooh", "print", "cinema"]
# lable all other vehicles as digital
# media_data['vehicle'] = media_data['vehicle'].apply(lambda x: x if x in non_digital else 'digital')
media_data = media_data.loc[media_data['vehicle'].isin(non_digital)]

media_data = media_data.groupby(['country', 'brand', 'year', 'month']).sum().reset_index()
media_data = media_data.drop(columns=['vehicle'])
media_data = media_data.sort_values(by=['country', 'brand', 'year', 'month']).reset_index(drop=True)


media_data = media_data.drop(columns=['maco'])
# media_data['roi'] = media_data['maco'] / media_data['spend']

media_data

Unnamed: 0,country,brand,year,month,spend
0,brazil,becks,2022,4,10762.870000
1,brazil,becks,2022,6,297087.060000
2,brazil,becks,2022,7,346177.266600
3,brazil,becks,2022,8,9353.539600
4,brazil,becks,2022,10,527002.860000
...,...,...,...,...,...
327,brazil,stella artois,2024,7,128461.584771
328,brazil,stella artois,2024,8,31294.081531
329,brazil,stella artois,2024,9,31395.130560
330,brazil,stella artois,2024,10,19132.004332


In [59]:
# print(media_data['vehicle'].unique())

In [60]:
print(media_data['brand'].value_counts())

brand
brahma           66
budweiser        59
skol             42
stella artois    41
corona           39
spaten           34
guarana          32
becks            19
Name: count, dtype: int64


In [42]:
media_data_sorted_backup = media_data.copy()

# Neilsen

In [18]:
# import files for pos data
pos_data = pd.read_csv(f"{input_path}/mroi/brazil_pos_data.csv")

pos_data = pos_data[pos_data['regiondesc'] == "ambev_total_brazil"]
pos_data['vol*wd'] = pos_data['volume_hl'] * pos_data['wd']
pos_data = pos_data.drop(columns=['regiondesc', 'channeldesc', 'product', 'wd'])
pos_data['timedesc'] = pd.to_datetime(pos_data['timedesc'], format='%Y-%m-%d')
pos_data['year'] = pos_data['timedesc'].dt.year
pos_data['month'] = pos_data['timedesc'].dt.month
pos_data = pos_data.drop(columns=['timedesc'])

pos_data = pos_data.groupby(['country', 'brand', 'year', 'month']).sum().reset_index()

pos_data['wd'] = pos_data['vol*wd'] / pos_data['volume_hl']
pos_data = pos_data.drop(columns=['vol*wd'])
pos_data['price_usd'] = pos_data['sales_usd'] / pos_data['volume_hl']

pos_data

Unnamed: 0,country,brand,year,month,volume_hl,price_usd,sales_usd,wd
0,brazil,becks,2021,4,15434.858220,333.266130,5.143915e+06,46.006752
1,brazil,becks,2021,5,17181.147640,349.999377,6.013391e+06,39.852145
2,brazil,becks,2021,6,17317.632840,362.576849,6.278973e+06,38.497249
3,brazil,becks,2021,7,18390.469200,356.069613,6.548287e+06,40.243897
4,brazil,becks,2021,8,18475.981540,356.806456,6.592349e+06,40.000063
...,...,...,...,...,...,...,...,...
352,brazil,stella artois,2025,2,161553.590000,322.455414,5.209383e+07,45.723218
353,brazil,stella artois,2025,3,181431.400000,319.028870,5.788185e+07,46.700260
354,brazil,stella artois,2025,4,157264.200000,321.187051,5.051122e+07,45.397008
355,brazil,stella artois,2025,5,158956.800000,319.539801,5.079302e+07,46.411804


# Concantenate

In [61]:
final_df = bg_data.copy()
final_df = final_df.merge(weather_data, on=['country', 'year', 'month'], how='left')
final_df = final_df.merge(macro_data, on=['country', 'year', 'month'], how='left')
final_df = final_df.merge(media_data, on=['country', 'brand', 'year', 'month'], how='left')
final_df = final_df.merge(pos_data, on=['country', 'brand', 'year', 'month'], how='left')
final_df

Unnamed: 0,country,brand,year,month,power,time_idx,month_sin,month_cos,avgtemp,maxtemp,mintemp,prcp,inflation_rate,unemployment_rate,spend,volume_hl,price_usd,sales_usd,wd
0,brazil,brahma,2021,3,52.0,0,1.000000,0.000000,78.217462,92.275893,66.104464,0.285617,6.099479,14.590725,,,,,
1,brazil,brahma,2021,4,52.2,1,0.866025,-0.500000,77.396357,91.137500,64.177679,0.186533,6.759304,14.249275,,1.839577e+06,182.123021,3.350294e+08,58.048098
2,brazil,brahma,2021,5,52.2,2,0.500000,-0.866025,75.128937,89.091429,60.326429,0.141963,8.056065,14.249275,190261.500000,1.902111e+06,185.212284,3.522944e+08,58.210640
3,brazil,brahma,2021,6,52.4,3,0.000000,-1.000000,74.388775,88.361607,59.716964,0.103733,8.347072,14.249275,97304.240000,1.792848e+06,185.524685,3.326176e+08,57.777443
4,brazil,brahma,2021,7,52.2,4,-0.500000,-0.866025,72.554099,88.608571,54.508571,0.081499,8.994823,13.101993,164286.666028,1.752503e+06,186.505429,3.268512e+08,58.643422
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
200,brazil,stella artois,2024,3,11.6,36,1.000000,0.000000,80.208747,93.802143,68.008571,0.278326,4.254281,7.966095,15529.292848,1.180883e+05,395.237621,4.667294e+07,44.719785
201,brazil,stella artois,2024,4,11.5,37,0.866025,-0.500000,79.001819,91.968750,66.563393,0.243686,3.919996,8.059917,653401.377130,1.162377e+05,388.451738,4.515273e+07,45.651960
202,brazil,stella artois,2024,5,11.5,38,0.500000,-0.866025,78.316303,91.694643,63.684821,0.175895,3.797215,8.059917,675822.190650,1.210071e+05,386.609308,4.678246e+07,47.091780
203,brazil,stella artois,2024,6,11.5,39,0.000000,-1.000000,75.155107,89.387857,60.156429,0.175970,4.206327,7.944695,117205.445742,1.198886e+05,386.514679,4.633870e+07,47.800273


In [23]:
final_df = final_df.fillna(0)
print(final_df.isna().sum())

country              0
brand                0
year                 0
month                0
power                0
avgtemp              0
maxtemp              0
mintemp              0
prcp                 0
inflation_rate       0
unemployment_rate    0
volume_hl            0
price_usd            0
sales_usd            0
wd                   0
dtype: int64


In [24]:
final_df

Unnamed: 0,country,brand,year,month,power,avgtemp,maxtemp,mintemp,prcp,inflation_rate,unemployment_rate,volume_hl,price_usd,sales_usd,wd
0,brazil,brahma,2021,3,52.0,78.217462,92.275893,66.104464,0.285617,6.099479,14.590725,0.000000e+00,0.000000,0.000000e+00,0.000000
1,brazil,brahma,2021,4,52.2,77.396357,91.137500,64.177679,0.186533,6.759304,14.249275,1.839577e+06,182.123021,3.350294e+08,58.048098
2,brazil,brahma,2021,5,52.2,75.128937,89.091429,60.326429,0.141963,8.056065,14.249275,1.902111e+06,185.212284,3.522944e+08,58.210640
3,brazil,brahma,2021,6,52.4,74.388775,88.361607,59.716964,0.103733,8.347072,14.249275,1.792848e+06,185.524685,3.326176e+08,57.777443
4,brazil,brahma,2021,7,52.2,72.554099,88.608571,54.508571,0.081499,8.994823,13.101993,1.752503e+06,186.505429,3.268512e+08,58.643422
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
200,brazil,stella artois,2024,3,11.6,80.208747,93.802143,68.008571,0.278326,4.254281,7.966095,1.180883e+05,395.237621,4.667294e+07,44.719785
201,brazil,stella artois,2024,4,11.5,79.001819,91.968750,66.563393,0.243686,3.919996,8.059917,1.162377e+05,388.451738,4.515273e+07,45.651960
202,brazil,stella artois,2024,5,11.5,78.316303,91.694643,63.684821,0.175895,3.797215,8.059917,1.210071e+05,386.609308,4.678246e+07,47.091780
203,brazil,stella artois,2024,6,11.5,75.155107,89.387857,60.156429,0.175970,4.206327,7.944695,1.198886e+05,386.514679,4.633870e+07,47.800273


# Model Input Creation

In [25]:
final_df_backup = final_df.copy()

In [26]:
df = final_df_backup.copy()

# time_idx column
df = df.sort_values(by=['country', 'year', 'month', 'brand']).reset_index(drop=True)
df

df = df.sort_values(['year', 'month'])
min_year, min_month = df['year'].min(), df['month'][df['year'] == df['year'].min()].min()

df['time_idx'] = (df['year'] - min_year) * 12 + (df['month'] - min_month)

# 3. Cyclical month encoding
df['month_sin'] = np.sin(2 * np.pi * df['month'] / 12.0).round(6)
df['month_cos'] = np.cos(2 * np.pi * df['month'] / 12.0).round(6)
df

Unnamed: 0,country,brand,year,month,power,avgtemp,maxtemp,mintemp,prcp,inflation_rate,unemployment_rate,volume_hl,price_usd,sales_usd,wd,time_idx,month_sin,month_cos
0,brazil,brahma,2021,3,52.0,78.217462,92.275893,66.104464,0.285617,6.099479,14.590725,0.000000e+00,0.000000,0.000000e+00,0.000000,0,1.0,0.000000
1,brazil,budweiser,2021,3,19.8,78.217462,92.275893,66.104464,0.285617,6.099479,14.590725,0.000000e+00,0.000000,0.000000e+00,0.000000,0,1.0,0.000000
2,brazil,corona,2021,3,11.5,78.217462,92.275893,66.104464,0.285617,6.099479,14.590725,0.000000e+00,0.000000,0.000000e+00,0.000000,0,1.0,0.000000
3,brazil,spaten,2021,3,4.0,78.217462,92.275893,66.104464,0.285617,6.099479,14.590725,0.000000e+00,0.000000,0.000000e+00,0.000000,0,1.0,0.000000
4,brazil,stella artois,2021,3,12.7,78.217462,92.275893,66.104464,0.285617,6.099479,14.590725,0.000000e+00,0.000000,0.000000e+00,0.000000,0,1.0,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
200,brazil,brahma,2024,7,46.4,73.665021,89.108036,57.716071,0.077463,3.835883,7.933507,1.748808e+06,245.770147,4.298048e+08,49.579925,40,-0.5,-0.866025
201,brazil,budweiser,2024,7,19.0,73.665021,89.108036,57.716071,0.077463,3.835883,7.933507,2.878564e+05,265.411869,7.640051e+07,48.479478,40,-0.5,-0.866025
202,brazil,corona,2024,7,13.5,73.665021,89.108036,57.716071,0.077463,3.835883,7.933507,5.791244e+04,430.635230,2.493914e+07,68.194613,40,-0.5,-0.866025
203,brazil,spaten,2024,7,9.3,73.665021,89.108036,57.716071,0.077463,3.835883,7.933507,1.993254e+05,331.999014,6.617583e+07,56.403063,40,-0.5,-0.866025


In [51]:
media_data = media_data_sorted_backup.copy()

media_data['time_idx'] = (media_data['year'] - min_year) * 12 + (media_data['month'] - min_month)
media_data['month_sin'] = np.sin(2 * np.pi * media_data['month'] / 12.0).round(6)
media_data['month_cos'] = np.cos(2 * np.pi * media_data['month'] / 12.0).round(6)
media_data = media_data.loc[media_data['time_idx'].isin(df['time_idx'].unique())]
media_data = media_data.loc[media_data['brand'].isin(df['brand'].unique())]
media_data = media_data.sort_values(by=['country',  'year', 'month','brand','vehicle']).reset_index(drop=True)

media_data.head(10)

Unnamed: 0,country,brand,vehicle,year,month,spend,time_idx,month_sin,month_cos
0,brazil,brahma,digital,2021,3,2039287.0,0,1.0,0.0
1,brazil,budweiser,digital,2021,3,185150.4,0,1.0,0.0
2,brazil,corona,digital,2021,3,406510.9,0,1.0,0.0
3,brazil,spaten,digital,2021,3,1926.752,0,1.0,0.0
4,brazil,stella artois,digital,2021,3,42831.99,0,1.0,0.0
5,brazil,brahma,digital,2021,4,2585893.0,1,0.866025,-0.5
6,brazil,budweiser,digital,2021,4,71931.36,1,0.866025,-0.5
7,brazil,corona,digital,2021,4,306106.5,1,0.866025,-0.5
8,brazil,spaten,digital,2021,4,883.4646,1,0.866025,-0.5
9,brazil,stella artois,digital,2021,4,68412.69,1,0.866025,-0.5


In [52]:
# from sklearn.preprocessing import StandardScaler
# scaler = StandardScaler()

numeric_cols = []
for col in df.select_dtypes(include=['float64', 'int64']).columns:
    if col not in ['year', 'month', 'power', 'time_idx']:
        numeric_cols.append(col)
for col in numeric_cols:
    if col in media_data.columns:
        media_data[col] = (media_data[col] - df[col].min()) / df[col].max()
    df[col] = (df[col] - df[col].min()) / df[col].max()
    df[col] = df[col].astype(np.float32)
df

Unnamed: 0,country,brand,year,month,power,avgtemp,maxtemp,mintemp,prcp,inflation_rate,unemployment_rate,volume_hl,price_usd,sales_usd,wd,time_idx,month_sin,month_cos
0,brazil,brahma,2021,3,52.0,0.068728,0.045359,0.169615,0.705794,0.262406,0.465659,0.000000,0.000000,0.000000,0.000000,0,2.0,1.000000
1,brazil,budweiser,2021,3,19.8,0.068728,0.045359,0.169615,0.705794,0.262406,0.465659,0.000000,0.000000,0.000000,0.000000,0,2.0,1.000000
2,brazil,corona,2021,3,11.5,0.068728,0.045359,0.169615,0.705794,0.262406,0.465659,0.000000,0.000000,0.000000,0.000000,0,2.0,1.000000
3,brazil,spaten,2021,3,4.0,0.068728,0.045359,0.169615,0.705794,0.262406,0.465659,0.000000,0.000000,0.000000,0.000000,0,2.0,1.000000
4,brazil,stella artois,2021,3,12.7,0.068728,0.045359,0.169615,0.705794,0.262406,0.465659,0.000000,0.000000,0.000000,0.000000,0,2.0,1.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
200,brazil,brahma,2024,7,46.4,0.013482,0.013491,0.046917,0.065781,0.075721,0.009396,0.764100,0.570715,0.851189,0.644790,40,0.5,0.133975
201,brazil,budweiser,2024,7,19.0,0.013482,0.013491,0.046917,0.065781,0.075721,0.009396,0.125772,0.616326,0.151304,0.630479,40,0.5,0.133975
202,brazil,corona,2024,7,13.5,0.013482,0.013491,0.046917,0.065781,0.075721,0.009396,0.025303,1.000000,0.049390,0.886875,40,0.5,0.133975
203,brazil,spaten,2024,7,9.3,0.013482,0.013491,0.046917,0.065781,0.075721,0.009396,0.087091,0.770952,0.131055,0.733525,40,0.5,0.133975


In [55]:
media_data['spend'] = (media_data['spend'] - media_data['spend'].min()) / media_data['spend'].max()
media_data

Unnamed: 0,country,brand,vehicle,year,month,spend,time_idx,month_sin,month_cos
0,brazil,brahma,digital,2021,3,0.027868,0,2.0,1.000000
1,brazil,budweiser,digital,2021,3,0.002530,0,2.0,1.000000
2,brazil,corona,digital,2021,3,0.005555,0,2.0,1.000000
3,brazil,spaten,digital,2021,3,0.000026,0,2.0,1.000000
4,brazil,stella artois,digital,2021,3,0.000585,0,2.0,1.000000
...,...,...,...,...,...,...,...,...,...
429,brazil,spaten,digital,2024,7,0.017993,40,0.5,0.133975
430,brazil,spaten,ooh,2024,7,0.009553,40,0.5,0.133975
431,brazil,spaten,paytv,2024,7,0.000152,40,0.5,0.133975
432,brazil,stella artois,digital,2024,7,0.013048,40,0.5,0.133975


In [None]:
months_for_test = 3
index_for_train_test_cut = df['time_idx'].max() - months_for_test
train_cut_off = df[df['time_idx'] <= index_for_train_test_cut].index.max()
print(train_cut_off)

In [None]:
# group id by country-year-month
df['group_id'] = df.groupby(['country', 'year', 'month']).ngroup()

group_id = df['group_id'].values
print(group_id)

In [None]:
# time_idx variable
time_idx = (df['time_idx'] + 1).values
print(time_idx)

# month sin and cos
month_sin = df['month_sin'].values
month_cos = df['month_cos'].values
print(month_sin)
print(month_cos)

In [None]:
# country and brand categorical variables
country_id = df['country'].astype('category').cat.codes.values
brand_id = df['brand'].astype('category').cat.codes.values
print(country_id)
print(brand_id)

In [None]:
# controls
n_controls = 11
control_cols = [x for x in df.columns if x not in ['country', 'brand', 'year', 'month', 'power', 'time_idx', 'month_sin', 'month_cos', 'group_id']]
n_controls = len(control_cols)

control_ids = df[control_cols]


print(n_controls, control_cols, control_ids)
print()
print(np.array(control_cols).shape)

In [None]:
y_true = df['power'].values
print(y_true.shape)

In [None]:
def fix_shape(arr, dtype=None):
    arr = np.array(arr)  # ensure ndarray
    arr = arr.reshape(arr.shape[0], -1)
    print(arr.shape)
    if arr.shape[0] < arr.shape[1]:
        arr = arr.T
    if dtype is not None:
        arr = arr.astype(dtype)
    print(arr.shape)
    print()
    return arr

time_idx   = fix_shape(time_idx, dtype="int32")
month_sin  = fix_shape(month_sin, dtype="float32")
month_cos  = fix_shape(month_cos, dtype="float32")
country_id = fix_shape(country_id, dtype="int32")
brand_id   = fix_shape(brand_id, dtype="int32")
controls   = fix_shape(control_ids, dtype="float32")
y_true     = np.array(y_true, dtype="float32").reshape(-1)   # (N,)
group_id   = np.array(group_id, dtype="int32").reshape(-1)   # (N,)

In [None]:
# train/test split
def train_test_split(arr, train_cut_off=train_cut_off):
    """Split the data into training and testing sets. 
    Input:
        arr: numpy array of (rows, cols)
        train_cut_off: index to split the data into training and testing sets

    Args:
        arr (np.ndarray): Input array to split.
        train_cut_off (int, optional): Index to split the data into training and testing sets. Defaults to train_cut_off.
    Returns:
        tuple: A tuple containing the training and testing sets for that array.
    """
    return arr[:train_cut_off+1], arr[train_cut_off+1:]

time_idx_train, time_idx_test = train_test_split(time_idx)
month_sin_train, month_sin_test = train_test_split(month_sin)
month_cos_train, month_cos_test = train_test_split(month_cos)
country_id_train, country_id_test = train_test_split(country_id)
brand_id_train, brand_id_test = train_test_split(brand_id)
controls_train, controls_test = train_test_split(controls)
y_true_train, y_true_test = train_test_split(y_true)    
group_id_train, group_id_test = train_test_split(group_id)

print(time_idx_train.shape, time_idx_test.shape)

In [None]:
# controls_test = np.zeros_like(controls_test)

# Pickle Output

In [None]:
# dump in a pickle file
output_path = "C:/Users/40107904/OneDrive - Anheuser-Busch InBev/ABI/WORK/hackathon_power/hackathon_lt_equity/dummy_data/processed_data"
output_file = f"{output_path}/preprocessed_data.pkl"

input_data_dict = {
    "n_controls": n_controls,
    "control_cols": control_cols,
    # "control_scaler": scaler,
    "time_idx": time_idx_train,
    "month_sin": month_sin_train,
    "month_cos": month_cos_train,
    "country_id": country_id_train,
    "brand_id": brand_id_train,
    "controls": controls_train,
    "y_true": y_true_train,
    "group_id": group_id_train
}

output_data_dict = {
    "time_idx": time_idx_test,
    "month_sin": month_sin_test,
    "month_cos": month_cos_test,
    "country_id": country_id_test,
    "brand_id": brand_id_test,
    "controls": controls_test,
    "y_true": y_true_test,
    "group_id": group_id_test
}

final_dict = {
    "input_data": input_data_dict,
    "output_data": output_data_dict
}

with open(output_file, "wb") as f:
    pkl.dump(final_dict, f)