In [1]:
import os

print("Exists:", os.path.exists('/content/drive/MyDrive/csad_project/data/multi_region_mmm.csv'))


Exists: False


In [2]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [4]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
from copy import deepcopy




# ===============================
# 📦 Data Preparation
# ===============================


df = pd.read_csv('/content/drive/MyDrive/csad_project/data/multi_region_mmm.csv')
df['DATE_DAY'] = pd.to_datetime(df['DATE_DAY'])
df['week'] = df['DATE_DAY'].dt.isocalendar().week
df['year'] = df['DATE_DAY'].dt.year


agg_df = df.groupby(['ORGANISATION_ID', 'TERRITORY_NAME', 'week', 'year']).agg({
    'GOOGLE_PAID_SEARCH_SPEND': 'sum',
    'META_FACEBOOK_SPEND': 'sum',
    'META_INSTAGRAM_SPEND': 'sum',
    'GOOGLE_VIDEO_SPEND': 'sum',
    'EMAIL_CLICKS': 'sum',
    'REFERRAL_CLICKS': 'sum',
    'GOOGLE_DISPLAY_SPEND': 'sum',
    'BRANDED_SEARCH_CLICKS': 'sum',
    'ALL_PURCHASES': 'sum'
}).reset_index()


agg_df = agg_df.rename(columns={
    'ORGANISATION_ID': 'brand_id',
    'TERRITORY_NAME': 'region',
    'GOOGLE_PAID_SEARCH_SPEND': 'search',
    'META_FACEBOOK_SPEND': 'facebook',
    'META_INSTAGRAM_SPEND': 'instagram',
    'GOOGLE_VIDEO_SPEND': 'video',
    'EMAIL_CLICKS': 'email',
    'REFERRAL_CLICKS': 'affiliate',
    'GOOGLE_DISPLAY_SPEND': 'display',
    'BRANDED_SEARCH_CLICKS': 'promotion',
    'ALL_PURCHASES': 'sales'
})


agg_df['week_idx'] = pd.factorize(agg_df['year'].astype(str) + '-' + agg_df['week'].astype(str))[0]
agg_df['sales'] = (agg_df['sales'] - agg_df['sales'].mean()) / (agg_df['sales'].std() + 1e-6)


# ===============================
# 🧱 Dataset
# ===============================


class CSADSequenceDataset(Dataset):
    def __init__(self, df, window_size=32):
        self.window_size = window_size
        self.media_cols = ['search', 'facebook', 'instagram', 'video', 'email', 'affiliate', 'display', 'promotion']
        self.brand2idx = {b: i for i, b in enumerate(df['brand_id'].unique())}
        self.region2idx = {r: i for i, r in enumerate(df['region'].unique())}
        df['brand_idx'] = df['brand_id'].map(self.brand2idx)
        df['region_idx'] = df['region'].map(self.region2idx)


        self.samples = []
        for (brand, region), group in df.groupby(['brand_id', 'region']):
            group = group.sort_values(by='week_idx')
            if len(group) < window_size + 1:
                continue
            for i in range(len(group) - window_size):
                window = group.iloc[i:i+window_size]
                target = group.iloc[i+window_size]['sales']
                self.samples.append({
                    'media_seq': window[self.media_cols].values.astype(np.float32),
                    'week_idx_seq': window['week_idx'].values.astype(np.int64),
                    'brand_idx': window['brand_idx'].iloc[0],
                    'region_idx': window['region_idx'].iloc[0],
                    'target_sales': target
                })


    def __len__(self): return len(self.samples)


    def __getitem__(self, idx):
        s = self.samples[idx]
        return {
            'media_seq': torch.tensor(s['media_seq']),
            'week_idx_seq': torch.tensor(s['week_idx_seq']),
            'brand_idx': torch.tensor(s['brand_idx']),
            'region_idx': torch.tensor(s['region_idx']),
            'target_sales': torch.tensor(s['target_sales'], dtype=torch.float32)
        }


# ===============================
# 🧠 Model
# ===============================


class PositionalEncoder(nn.Module):
    def __init__(self, max_weeks, embed_dim):
        super().__init__()
        self.week_embedding = nn.Embedding(max_weeks, embed_dim)


    def forward(self, week_idx):
        return self.week_embedding(week_idx)


class VariableSelectionGate(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, 1),
            nn.Sigmoid()
        )


    def forward(self, x):
        weight = self.gate(x)
        return weight * x


class ChannelEncoder(nn.Module):
    def __init__(self, week_embed_dim=8, hidden_dim=32):
        super().__init__()
        self.input_dim = 1 + week_embed_dim
        self.proj = nn.Linear(self.input_dim, hidden_dim)
        self.attn_score = nn.Linear(hidden_dim, 1)
        self.norm = nn.LayerNorm(hidden_dim)


    def forward(self, spend_seq, week_embed):
        x = torch.cat([spend_seq, week_embed], dim=-1)
        x_proj = self.proj(x)
        scores = self.attn_score(x_proj).squeeze(-1)
        weights = torch.softmax(scores, dim=1).unsqueeze(-1)
        x_weighted = (x_proj * weights).sum(dim=1)
        return self.norm(x_weighted)


class GraphGuidtedMultiheadAttenion(nn.Module):
    def __init__(self, embed_dim, num_heads, num_channels):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.graph_adj = nn.Parameter(torch.ones(num_channels, num_channels))


    def forward(self, q, k, v):
        attn_out, attn_weights = self.attn(q, k, v)
        masked_weights = attn_weights * self.graph_adj.unsqueeze(0)
        output = torch.bmm(masked_weights, v)
        return output, masked_weights


class CSADModel(nn.Module):
    def __init__(self, num_channels=8, week_embed_dim=8, hidden_dim=64, context_embed_dim=8,
                 brand_vocab_size=100, region_vocab_size=20, max_weeks=300):
        super().__init__()
        self.pos_encoder = PositionalEncoder(max_weeks, week_embed_dim)
        self.gates = nn.ModuleList([VariableSelectionGate(1) for _ in range(num_channels)])
        self.channel_encoders = nn.ModuleList([ChannelEncoder(week_embed_dim, hidden_dim) for _ in range(num_channels)])
        self.cross_attn = GraphGuidtedMultiheadAttenion(embed_dim=hidden_dim, num_heads=1, num_channels=num_channels)
        self.cross_norm = nn.LayerNorm(hidden_dim * num_channels)
        self.brand_emb = nn.Embedding(brand_vocab_size, context_embed_dim)
        self.region_emb = nn.Embedding(region_vocab_size, context_embed_dim)
        self.brand_proj = nn.Linear(context_embed_dim, hidden_dim)
        self.region_proj = nn.Linear(context_embed_dim, hidden_dim)
        self.context_gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1), nn.Sigmoid()
        )
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * num_channels + hidden_dim, 256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 1)
        )


    def forward(self, media_seq, week_idx_seq, brand_idx, region_idx):
        B, T, C = media_seq.shape
        week_embed = self.pos_encoder(week_idx_seq)
        reps = []
        for i in range(C):
            x = media_seq[:, :, i].unsqueeze(-1)
            x = self.gates[i](x)
            rep = self.channel_encoders[i](x, week_embed)
            reps.append(rep)
        x_stack = torch.stack(reps, dim=1)
        attn_out, _ = self.cross_attn(x_stack, x_stack, x_stack)
        x_flat = self.cross_norm(attn_out.reshape(B, -1))
        b_proj = self.brand_proj(self.brand_emb(brand_idx))
        r_proj = self.region_proj(self.region_emb(region_idx))
        gate = self.context_gate(torch.cat([b_proj, r_proj], dim=-1))
        context = gate * b_proj + (1 - gate) * r_proj
        return self.fc(torch.cat([x_flat, context], dim=-1)).squeeze(-1)


# ===============================
# 🧪 Train/Test Utilities
# ===============================


def build_loaders(df, window_size=32, batch_size=32):
    dataset = CSADSequenceDataset(df, window_size)
    group_to_indices = {}
    for idx, s in enumerate(dataset.samples):
        k = (int(s['brand_idx']), int(s['region_idx']))
        group_to_indices.setdefault(k, []).append(idx)
    all_groups = list(group_to_indices.keys())
    gss = GroupShuffleSplit(n_splits=1, train_size=0.8, random_state=42)
    train_g, val_g = next(gss.split(np.arange(len(all_groups)), groups=np.arange(len(all_groups))))
    train_idx = [idx for g in train_g for idx in group_to_indices[all_groups[g]]]
    val_idx = [idx for g in val_g for idx in group_to_indices[all_groups[g]]]
    return dataset, DataLoader(Subset(dataset, train_idx), batch_size, shuffle=True), DataLoader(Subset(dataset, val_idx), batch_size)


@torch.no_grad()
def mc_dropout_predict(model, batch, T=30):
    model.eval()
    for m in model.modules():
        if isinstance(m, nn.Dropout): m.train()
    preds = [model(batch['media_seq'], batch['week_idx_seq'], batch['brand_idx'], batch['region_idx']).cpu().numpy() for _ in range(T)]
    return np.stack(preds).mean(0), np.stack(preds).std(0)


@torch.no_grad()
def counterfactual_predict(model, factual_batch, intervention_fn, T=30):
    counter_batch = {k: v.clone() for k, v in factual_batch.items()}
    counter_batch = intervention_fn(counter_batch)
    return mc_dropout_predict(model, counter_batch, T)


def train_one_epoch(model, loader, loss_fn, optimizer, device):
    model.train()
    total_loss = 0
    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        pred = model(batch['media_seq'], batch['week_idx_seq'], batch['brand_idx'], batch['region_idx'])
        loss = loss_fn(pred, batch['target_sales'])
        optimizer.zero_grad(); loss.backward(); optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


@torch.no_grad()
def evaluate(model, loader, loss_fn, device, use_mc=False):
    model.eval(); y_true, y_pred, stds = [], [], []
    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        y_true.extend(batch['target_sales'].cpu().numpy())
        if use_mc:
            m, s = mc_dropout_predict(model, batch); y_pred.extend(m); stds.extend(s)
        else:
            y_pred.extend(model(batch['media_seq'], batch['week_idx_seq'], batch['brand_idx'], batch['region_idx']).cpu().numpy())
    mse = mean_squared_error(y_true, y_pred)
    return mse, np.sqrt(mse), r2_score(y_true, y_pred), stds if use_mc else None


def plot_losses(train, val):
    plt.plot(train, label="Train"); plt.plot(val, label="Val")
    plt.legend(); plt.title("Loss"); plt.grid(); plt.show()


# ===============================
# 🚀 Training Loop
# ===============================


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
brand_vocab_size = agg_df['brand_id'].nunique()
region_vocab_size = agg_df['region'].nunique()


model = CSADModel(brand_vocab_size=brand_vocab_size, region_vocab_size=region_vocab_size).to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)


seq_dataset, train_loader, val_loader = build_loaders(agg_df)


train_losses, val_losses = [], []
best_val_rmse = float('inf')
early_stop_counter = 0
patience = 5


for epoch in range(25):
    train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer, device)
    mse, rmse, r2, _ = evaluate(model, val_loader, loss_fn, device, use_mc=True)
    train_losses.append(train_loss); val_losses.append(mse)
    print(f"Epoch {epoch+1:03d} | Train Loss: {train_loss:.4f} | Val RMSE: {rmse:.4f} | R²: {r2:.4f}")
    if rmse < best_val_rmse:
        best_val_rmse = rmse
        best_model_state = deepcopy(model.state_dict())
        early_stop_counter = 0
        torch.save(model.state_dict(), "csad_best_model.pt")
        print("✅ Saved best model (RMSE improved)")
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print("🛑 Early stopping triggered")
            break


plot_losses(train_losses, val_losses)




Epoch 001 | Train Loss: 1.5858 | Val RMSE: 0.2083 | R²: -0.1300
✅ Saved best model (RMSE improved)


KeyboardInterrupt: 