# Teacher (DNN) → Student (Tiny DNN) Distillation pipeline

**Objetivo:** entrenar offline una DNN *teacher* más grande y luego destilar su conocimiento a un *student* pequeño y rápido para la predicción de `iap_revenue_d7` (target).

Este notebook está diseñado para datasets tabulares como el del hackathon (features de request + historial de usuario). Incluye:
- carga y preprocesado (intenta detectar archivos en `/mnt/data`)
- definición de modelos en **PyTorch** (teacher y student)
- esquema de entrenamiento del teacher (rápido, ejemplar)
- esquema de distillation (loss combinado: `alpha * L_hard + (1 - alpha) * L_soft`)
- evaluación básica (MSLE sobre predicciones en `log1p`; el AUC del buyer flag se puede calcular con `roc_auc_score`, ya importado)

> Nota: Este notebook se entrega como plantilla lista para ejecutar. Ajusta hiperparámetros, dimensiones de embeddings y número de épocas según tu hardware y tamaño del dataset.

In [None]:
import os
import math
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import mean_squared_log_error, roc_auc_score
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import joblib

# Reproducibility: one shared seed applied to every RNG the notebook uses
RSEED = 42
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(RSEED)


In [None]:
# Try to locate a dataset in /mnt/data
DATA_DIR = '/mnt/data'
# File extensions that are considered dataset candidates.
DATASET_EXTENSIONS = ('.csv', '.parquet', '.feather', '.pkl', '.zip')
# Files this notebook itself writes into DATA_DIR; they must never be mistaken
# for the input dataset when the notebook is re-run.
NOTEBOOK_ARTIFACTS = {'encoders_joblib.pkl'}

# A missing directory simply means "no dataset available" (the synthetic
# fallback is used) instead of crashing with FileNotFoundError.
if os.path.isdir(DATA_DIR):
    # Sorted so the chosen file does not depend on the arbitrary order
    # returned by os.listdir.
    candidates = sorted(
        f for f in os.listdir(DATA_DIR)
        if f.lower().endswith(DATASET_EXTENSIONS) and f not in NOTEBOOK_ARTIFACTS
    )
else:
    candidates = []
print('Datasets found in /mnt/data:', candidates)

# First candidate (alphabetically) wins; None triggers the synthetic dataset.
data_path = os.path.join(DATA_DIR, candidates[0]) if candidates else None

if data_path is None:
    print('No dataset found in /mnt/data. The notebook will create a SMALL synthetic dataset for demonstration. Replace it with your real dataset.')
else:
    print('Using dataset:', data_path)


In [None]:
# Load dataset or create synthetic
if data_path is not None:
    # Dispatch on the (case-insensitive) file extension.
    _path_lower = data_path.lower()
    if _path_lower.endswith('.parquet'):
        data = pd.read_parquet(data_path)
    elif _path_lower.endswith('.feather'):
        # Feather is a binary format; falling through to read_csv would fail.
        data = pd.read_feather(data_path)
    elif _path_lower.endswith('.pkl'):
        # NOTE(review): unpickling can execute arbitrary code; only load
        # pickle files that come from a trusted source.
        data = pd.read_pickle(data_path)
    else:
        # '.csv' and '.zip' (a zipped CSV: pandas infers the compression from
        # the extension), plus any unknown extension as a best effort.
        data = pd.read_csv(data_path)
else:
    # Create a small synthetic dataset for demonstration
    N = 20000
    rng = np.random.RandomState(RSEED)
    data = pd.DataFrame({
        'device_model': rng.choice([f'dm_{i}' for i in range(200)], size=N),
        'os_version': rng.choice([f'os_{i}' for i in range(10)], size=N),
        'country': rng.choice(['US','BR','IN','DE','FR','CN'], size=N),
        'hour': rng.randint(0,24,size=N),
        'past_purchases': rng.poisson(0.2, size=N),
        'avg_spend_30d': rng.exponential(5.0, size=N),
        'imp_count_7d': rng.poisson(3, size=N),
    })
    # Buyer probability is a logistic function of purchase/impression history.
    buyer_prob = 1 / (1 + np.exp(-(-2 + 0.05*data['past_purchases'] + 0.01*data['imp_count_7d'])))
    is_buyer = rng.binomial(1, buyer_prob)
    # Revenue is zero for non-buyers; for buyers it is log-normal around the
    # recent average spend, scaled by a random multiplier.
    revenue = is_buyer * np.exp(rng.normal(np.log(5+data['avg_spend_30d']), 1.0)) * rng.choice([0.5,1,2,5], size=N, p=[0.5,0.3,0.15,0.05])
    revenue = np.clip(revenue, 0, None)
    data['buyer_d7'] = is_buyer
    data['iap_revenue_d7'] = revenue
    print('Synthetic data created.')

print('Data shape:', data.shape)
data.head()


In [None]:
# Basic preprocessing: detect categorical and numeric features, prepare target
# Columns this notebook derives from the target in a later cell. They must
# never be picked up as a feature or as the fallback target if this cell is
# re-run afterwards, otherwise the target leaks into the inputs.
derived_cols = {'target_log1p'}
numeric_candidates = [
    c for c in data.select_dtypes(include=[np.number]).columns.tolist()
    if c not in derived_cols
]

# Fall back to the last numeric column when the expected target is absent.
target_col = 'iap_revenue_d7' if 'iap_revenue_d7' in data.columns else numeric_candidates[-1]
buyer_col = 'buyer_d7' if 'buyer_d7' in data.columns else None
reserved_cols = {target_col, buyer_col} | derived_cols

# Identify categorical columns (object or low-cardinality)
cat_cols = data.select_dtypes(include=['object', 'category']).columns.tolist()
# also treat low-cardinality ints as categorical
for c in data.select_dtypes(include=['int64','int32']).columns:
    if c not in reserved_cols and c not in cat_cols and data[c].nunique() <= 50:
        cat_cols.append(c)

# Every remaining numeric column is a continuous feature.
num_cols = [c for c in numeric_candidates if c not in reserved_cols and c not in cat_cols]

print('Target:', target_col)
print('Buyer flag:', buyer_col)
print('Categorical columns:', cat_cols)
print('Numeric columns:', num_cols)


In [None]:
# Fill missing numeric values with zero
if len(num_cols)>0:
    data[num_cols] = data[num_cols].fillna(0.0)

# Label encode categoricals and save encoders
encoders = {}
for c in cat_cols:
    col = data[c]
    # Convert to text first and then substitute 'NA' for missing entries.
    # Calling fillna('NA') directly raises on 'category' dtype columns when
    # 'NA' is not an existing category; this form works for every dtype and
    # yields the same strings for object / integer columns.
    as_text = col.astype(str).where(col.notna(), 'NA')
    le = LabelEncoder()
    data[c] = le.fit_transform(as_text)
    encoders[c] = le

# Target transform: log1p for stability (MSLE)
data['target_log1p'] = np.log1p(data[target_col].clip(lower=0.0))

# Train/val split (stratify by buyer if available)
if buyer_col is not None:
    train_df, val_df = train_test_split(data, test_size=0.15, random_state=RSEED, stratify=data[buyer_col])
else:
    train_df, val_df = train_test_split(data, test_size=0.15, random_state=RSEED)

print('Train shape:', train_df.shape, 'Val shape:', val_df.shape)

# Save encoders for later use
# The output directory may not exist yet (e.g. when running on synthetic data).
os.makedirs('/mnt/data', exist_ok=True)
joblib.dump(encoders, '/mnt/data/encoders_joblib.pkl')
print('Encoders saved to /mnt/data/encoders_joblib.pkl')


In [None]:
# PyTorch Dataset for tabular data with categorical embeddings
class TabularDataset(Dataset):
    """Expose a DataFrame as per-row dicts of tensors.

    Each item contains ``'y'`` (float32) and, when the corresponding columns
    were supplied, ``'cat'`` (long), ``'num'`` (float32) and ``'buyer'``
    (float32). Feature groups with no columns are stored as ``None`` and are
    left out of the returned dict.
    """

    def __init__(self, df, cat_cols, num_cols, target_col, buyer_col=None):
        # Categorical codes, one column per categorical feature.
        self.cat = None
        if len(cat_cols) > 0:
            self.cat = df[cat_cols].values.astype(np.int64)

        # Continuous features.
        self.num = None
        if len(num_cols) > 0:
            self.num = df[num_cols].values.astype(np.float32)

        # Regression target.
        self.y = df[target_col].values.astype(np.float32)

        # Optional binary buyer flag.
        self.buyer = None
        if buyer_col is not None:
            self.buyer = df[buyer_col].values.astype(np.int64)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # (key, backing array, tensor dtype) in the order the keys are emitted.
        fields = (
            ('cat', self.cat, torch.long),
            ('num', self.num, torch.float32),
            ('y', self.y, torch.float32),
            ('buyer', self.buyer, torch.float32),
        )
        return {
            key: torch.tensor(values[idx], dtype=dtype)
            for key, values, dtype in fields
            if values is not None
        }

# Build the train/validation datasets and their loaders (small batch size)
batch_size = 256

# Both splits share the same feature/target configuration.
train_ds, val_ds = (
    TabularDataset(frame, cat_cols, num_cols, 'target_log1p', buyer_col)
    for frame in (train_df, val_df)
)

# Only the training loader shuffles; neither drops the last partial batch.
train_loader = DataLoader(train_ds, shuffle=True, batch_size=batch_size, drop_last=False)
val_loader = DataLoader(val_ds, shuffle=False, batch_size=batch_size, drop_last=False)

print('Dataloaders ready.')


In [None]:
# Model utilities: Embedding sizes rule
def get_embedding_sizes(df, cat_cols, max_emb_dim=50):
    """Return ``(num_categories, embedding_dim)`` for each categorical column.

    ``num_categories`` must be strictly greater than the largest code that
    will be looked up in the embedding table. For integer-coded columns it is
    therefore at least ``max(code) + 1``, not just the number of distinct
    values, so a frame whose codes have gaps still gets a large enough table.
    The dimension follows the rule ``min(max(1, num_categories // 10),
    max_emb_dim)``.

    Args:
        df: DataFrame holding the (label-encoded) categorical columns.
        cat_cols: Names of the categorical columns, in model input order.
        max_emb_dim: Upper bound for any single embedding dimension.

    Returns:
        A list of ``(num_categories, embedding_dim)`` tuples, one per column.
    """
    emb_sizes = []
    for c in cat_cols:
        col = df[c]
        n_categories = int(col.nunique())
        # Only integer codes can be compared against the table size; other
        # dtypes keep the plain distinct-count behaviour.
        if len(col) > 0 and pd.api.types.is_integer_dtype(col):
            n_categories = max(n_categories, int(col.max()) + 1)
        emb_dim = min(max(1, n_categories // 10), max_emb_dim)
        emb_sizes.append((n_categories, emb_dim))
    return emb_sizes

# Teacher model (bigger)
class TeacherModel(nn.Module):
    """Large MLP over embedded categoricals and numeric features.

    The network has a shared trunk ending in a 128-unit layer, followed by two
    heads: a linear regression head (the last layer of ``self.net``) and a
    sigmoid buyer-probability head (``self.buyer_head``).

    Args:
        emb_sizes: List of ``(num_categories, embedding_dim)`` tuples, one per
            categorical column, in the column order of ``x_cat``.
        num_len: Number of numeric input features.
    """

    def __init__(self, emb_sizes, num_len):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(categories, dim) for categories, dim in emb_sizes])
        emb_dim_sum = sum([dim for _, dim in emb_sizes]) if len(emb_sizes)>0 else 0
        input_dim = emb_dim_sum + (num_len if num_len>0 else 0)
        self.net = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.15),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )
        self.buyer_head = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, x_cat, x_num):
        """Return ``(regression_output, buyer_probability)``, both flattened.

        Args:
            x_cat: Long tensor of category codes indexed as ``x_cat[:, i]`` for
                the i-th embedding, or ``None``.
            x_num: Float tensor of numeric features, or ``None``.

        Raises:
            ValueError: If neither usable categorical nor numeric input is
                given.
        """
        if x_cat is not None and len(self.embs)>0:
            parts = [emb(x_cat[:,i]) for i, emb in enumerate(self.embs)]
            if x_num is not None:
                parts.append(x_num)
            x = torch.cat(parts, dim=1)
        elif x_num is not None:
            x = x_num
        else:
            raise ValueError('TeacherModel.forward needs x_cat (with embeddings) and/or x_num')
        # Everything except the final Linear is the shared trunk; its 128-d
        # output feeds both heads.
        feat = self.net[:-1](x)
        out_reg = self.net[-1](feat)
        out_buyer = self.buyer_head(feat)
        return out_reg.view(-1), out_buyer.view(-1)

# Student model (small)
class StudentModel(nn.Module):
    """Compact MLP with a single regression output.

    Takes the same ``emb_sizes`` description as the teacher but uses half the
    embedding width per column (never below 1) and a 128 -> 64 -> 1 network.
    """

    def __init__(self, emb_sizes, num_len):
        super().__init__()
        # One embedding per categorical column, at half the given width.
        self.embs = nn.ModuleList()
        emb_total = 0
        for n_categories, full_dim in emb_sizes:
            half_dim = max(1, full_dim // 2)
            self.embs.append(nn.Embedding(n_categories, half_dim))
            emb_total += half_dim

        in_features = emb_total + max(num_len, 0)
        layers = [
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x_cat, x_num):
        use_embeddings = x_cat is not None and len(self.embs) > 0
        if not use_embeddings:
            # Numeric-only path.
            return self.net(x_num).view(-1)

        # Embed each categorical column, then append the numeric block.
        pieces = [layer(x_cat[:, col]) for col, layer in enumerate(self.embs)]
        if x_num is not None:
            pieces.append(x_num)
        return self.net(torch.cat(pieces, dim=1)).view(-1)


In [None]:
# Size the embedding tables from the FULL encoded frame, not just the training
# split: the label encoders were fit on all of `data`, so a category that only
# occurs in the validation split would otherwise index past the end of an
# embedding table sized from `train_df`.
emb_sizes = get_embedding_sizes(data, cat_cols, max_emb_dim=50)
print('Embedding sizes (categories, dim):', emb_sizes)

# Both models are built from the same embedding description.
teacher = TeacherModel(emb_sizes, len(num_cols)).to('cpu')
student = StudentModel(emb_sizes, len(num_cols)).to('cpu')

# Trainable parameter counts, to compare model sizes.
print('Teacher params:', sum(p.numel() for p in teacher.parameters() if p.requires_grad))
print('Student params:', sum(p.numel() for p in student.parameters() if p.requires_grad))


In [None]:
# Training skeleton for teacher
def train_teacher(model, loader, val_loader, epochs=3, lr=1e-3, device='cpu'):
    """Train the teacher on the regression target plus the buyer flag.

    The loss is ``MSE(pred, y) + 0.5 * BCE(pred_buyer, buyer)``; the BCE term
    is skipped for batches without a ``'buyer'`` entry. After every epoch the
    mean training loss is printed, together with the validation MSE when
    ``val_loader`` is given. The final weights are saved to
    ``/mnt/data/teacher_model.pt``.

    Args:
        model: Module whose ``forward(x_cat, x_num)`` returns
            ``(regression_output, buyer_probability)``.
        loader: Iterable of batch dicts with ``'y'`` and optional ``'cat'``,
            ``'num'``, ``'buyer'`` tensors.
        val_loader: Iterable of the same kind used for validation, or ``None``
            to skip validation.
        epochs: Number of passes over ``loader``.
        lr: Adam learning rate.
        device: Device the model and batches are moved to.

    Returns:
        None. When validation runs, the model is left in eval mode.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    mse_loss = nn.MSELoss()
    bce_loss = nn.BCELoss()
    model.to(device)
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        n_seen = 0
        for batch in loader:
            x_cat = batch['cat'].to(device) if 'cat' in batch else None
            x_num = batch['num'].to(device) if 'num' in batch else None
            y = batch['y'].to(device)
            buyer = batch['buyer'].to(device) if 'buyer' in batch else None
            opt.zero_grad()
            pred_log1p, pred_buyer = model(x_cat, x_num)
            loss = mse_loss(pred_log1p, y)
            if buyer is not None:
                loss = loss + 0.5 * bce_loss(pred_buyer, buyer)
            loss.backward()
            opt.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            train_loss += loss.item() * len(y)
            n_seen += len(y)
        # Normalise by the samples actually seen, which also covers loaders
        # that drop batches or have no `.dataset` attribute.
        train_loss /= max(n_seen, 1)
        summary = f'Epoch {epoch+1}/{epochs} train_loss={train_loss:.6f}'

        if val_loader is not None:
            # Validation MSE on the regression head only.
            model.eval()
            val_sq_err = 0.0
            n_val = 0
            with torch.no_grad():
                for batch in val_loader:
                    x_cat = batch['cat'].to(device) if 'cat' in batch else None
                    x_num = batch['num'].to(device) if 'num' in batch else None
                    y = batch['y'].to(device)
                    pred_log1p, _ = model(x_cat, x_num)
                    val_sq_err += torch.sum((pred_log1p - y) ** 2).item()
                    n_val += len(y)
            if n_val > 0:
                summary += f' val_mse={val_sq_err / n_val:.6f}'
        print(summary)

    # The output directory may not exist yet.
    os.makedirs('/mnt/data', exist_ok=True)
    torch.save(model.state_dict(), '/mnt/data/teacher_model.pt')
    print('Teacher saved to /mnt/data/teacher_model.pt')

print('Teacher training function defined. (call train_teacher(...) to run)')


In [None]:
# Distillation
def train_student_with_distillation(student, teacher, train_loader, val_loader, epochs=3, lr=1e-3, alpha=0.6, device='cpu'):
    """Train the student against the true target and the frozen teacher.

    Per batch the loss is ``alpha * MSE(student, y) + (1 - alpha) *
    MSE(student, teacher)``. The teacher is kept in eval mode and evaluated
    without gradients. After every epoch the mean distillation loss is
    printed, together with the student's validation MSE against the true
    target when ``val_loader`` is given. The final student weights are saved
    to ``/mnt/data/student_model.pt``.

    Args:
        student: Module whose ``forward(x_cat, x_num)`` returns one tensor.
        teacher: Module whose ``forward(x_cat, x_num)`` returns a pair; only
            the first element (regression output) is used.
        train_loader: Iterable of batch dicts with ``'y'`` and optional
            ``'cat'`` / ``'num'`` tensors.
        val_loader: Iterable of the same kind used for validation, or ``None``
            to skip validation.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        alpha: Weight of the hard (true-target) loss; ``1 - alpha`` goes to
            the soft (teacher) loss.
        device: Device the models and batches are moved to.

    Returns:
        None. When validation runs, the student is left in eval mode.
    """
    opt = torch.optim.Adam(student.parameters(), lr=lr, weight_decay=1e-5)
    mse = nn.MSELoss()
    teacher.to(device)
    teacher.eval()
    student.to(device)
    for epoch in range(epochs):
        student.train()
        total_loss = 0.0
        n_seen = 0
        for batch in train_loader:
            x_cat = batch['cat'].to(device) if 'cat' in batch else None
            x_num = batch['num'].to(device) if 'num' in batch else None
            y = batch['y'].to(device)
            with torch.no_grad():
                # Only the regression output serves as the soft target.
                t_pred_log1p, _ = teacher(x_cat, x_num)
            s_pred = student(x_cat, x_num)
            loss_hard = mse(s_pred, y)
            loss_soft = mse(s_pred, t_pred_log1p)
            loss = alpha * loss_hard + (1.0 - alpha) * loss_soft
            opt.zero_grad()
            loss.backward()
            opt.step()
            # Weight by batch size so the epoch figure is a per-sample mean.
            total_loss += loss.item() * len(y)
            n_seen += len(y)
        # Normalise by the samples actually seen, which also covers loaders
        # that drop batches or have no `.dataset` attribute.
        total_loss /= max(n_seen, 1)
        summary = f'Epoch {epoch+1}/{epochs} distill_loss={total_loss:.6f}'

        if val_loader is not None:
            # Validation MSE of the student against the true target.
            student.eval()
            val_sq_err = 0.0
            n_val = 0
            with torch.no_grad():
                for batch in val_loader:
                    x_cat = batch['cat'].to(device) if 'cat' in batch else None
                    x_num = batch['num'].to(device) if 'num' in batch else None
                    y = batch['y'].to(device)
                    val_sq_err += torch.sum((student(x_cat, x_num) - y) ** 2).item()
                    n_val += len(y)
            if n_val > 0:
                summary += f' val_mse={val_sq_err / n_val:.6f}'
        print(summary)

    # The output directory may not exist yet.
    os.makedirs('/mnt/data', exist_ok=True)
    torch.save(student.state_dict(), '/mnt/data/student_model.pt')
    print('Student saved to /mnt/data/student_model.pt')

print('Distillation training function defined. (call train_student_with_distillation(...) to run)')


In [None]:
# Evaluation utilities
def predict_model_regression(model, loader, device='cpu'):
    """Run ``model`` over ``loader`` and collect predictions as NumPy arrays.

    Works for models returning a single tensor (student) as well as models
    returning a tuple/list whose first element is the regression output
    (teacher).

    Args:
        model: Module called as ``model(x_cat, x_num)``.
        loader: Iterable of batch dicts with ``'y'`` and optional ``'cat'``,
            ``'num'``, ``'buyer'`` tensors.
        device: Device the model and inputs are moved to.

    Returns:
        ``(preds, trues, buyers)``: flattened arrays of predictions and
        targets, and the buyer flags or ``None`` when no batch had a
        ``'buyer'`` entry. An empty loader yields two empty arrays and
        ``None``.
    """
    model.to(device)
    model.eval()
    preds = []
    trues = []
    buyers = []
    with torch.no_grad():
        for batch in loader:
            x_cat = batch['cat'].to(device) if 'cat' in batch else None
            x_num = batch['num'].to(device) if 'num' in batch else None
            out = model(x_cat, x_num)
            if isinstance(out, (tuple, list)):
                # Multi-head model: the regression output comes first.
                out = out[0]
            preds.append(out.cpu().numpy())
            trues.append(batch['y'].cpu().numpy())
            if 'buyer' in batch:
                buyers.append(batch['buyer'].cpu().numpy())
    if not preds:
        # np.concatenate raises on an empty list; return empty results instead.
        return np.empty(0, dtype=np.float32), np.empty(0, dtype=np.float32), None
    preds = np.concatenate(preds).ravel()
    trues = np.concatenate(trues).ravel()
    buyers = np.concatenate(buyers).ravel() if buyers else None
    return preds, trues, buyers

def msle_from_log_predictions(pred_log1p, true_log1p):
    """Compute MSLE from predictions and targets given in log1p space.

    Both inputs are mapped back with ``expm1`` and floored at zero before
    being passed to ``mean_squared_log_error``.
    """
    # Undo the log1p transform and clip negatives, for predictions and targets alike.
    pred, true = (
        np.clip(np.expm1(values), 0, None)
        for values in (pred_log1p, true_log1p)
    )
    return mean_squared_log_error(true, pred)

print('Evaluation utilities defined.')


## Cómo usar este notebook

1. Si tienes un dataset, colócalo en `/mnt/data` (CSV o parquet). El notebook auto-detecta archivos ahí.  
2. Ajusta `cat_cols` y `num_cols` si lo deseas (el notebook intenta detectarlos automáticamente).  
3. Entrena el **teacher** (p. ej. `train_teacher(teacher, train_loader, val_loader, epochs=5, lr=1e-3)`).  
4. Después, entrena al **student** con distillation (`train_student_with_distillation(student, teacher, train_loader, val_loader, epochs=5, alpha=0.6)`).  
5. Evalúa usando las funciones de evaluación o guarda modelos para exportar.  

**Tips**:
- Para producción, guarda sólo el `student` y exporta a ONNX + quantization.
- Ajusta la regla de `emb_sizes` para reducir las dimensiones de los embeddings en el student y acelerar la inferencia.
- Considera usar `log1p` y MSLE en todas las fases.
