# ANN for Customer Churn (PyTorch)

**Dataset**: Place `Churn_Modelling.csv` next to this notebook.

**Goal**: Build, tune, and evaluate a feedforward ANN for binary classification (churn).

**References**:
- Kaggle dataset: Deep Learning A-Z ANN (Churn Modelling)
- Article: Building an ANN with PyTorch (Jillani Soft Tech)


In [None]:
# Imports and setup
import os, random, numpy as np, pandas as pd
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader

try:
    from sklearn.metrics import roc_auc_score, confusion_matrix, classification_report
    SKLEARN_AVAILABLE = True
except Exception:
    SKLEARN_AVAILABLE = False

# Seed every RNG the notebook draws from (Python, NumPy, PyTorch) so runs repeat.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch.
_has_cuda = torch.cuda.is_available()
DEVICE = torch.device("cuda" if _has_cuda else "cpu")
print("Using:", DEVICE)


In [None]:
# Load the churn CSV; fail early with an actionable message if it is missing.
DATA_PATH = "Churn_Modelling.csv"
if os.path.exists(DATA_PATH):
    raw_df = pd.read_csv(DATA_PATH)
else:
    raise FileNotFoundError("Place 'Churn_Modelling.csv' in this folder or update DATA_PATH.")
print("Shape:", raw_df.shape)
raw_df.head()  # last expression of the cell: rendered as the cell output in a notebook


In [None]:
# Quick EDA: column names, class balance of the target, and missing values per column.
# The newlines are written as "\n" escapes: a raw line break inside a
# single-quoted string literal is a SyntaxError.
print("Columns:", list(raw_df.columns))
print("Class distribution (Exited):\n", raw_df['Exited'].value_counts())
print("Missing values:\n", raw_df.isna().sum())


In [None]:
# Preprocessing: drop identifier columns, encode categoricals, split 60/20/20,
# then standardise all splits with statistics from the training split only.
df = raw_df.drop(columns=['RowNumber', 'CustomerId', 'Surname'])
df['Gender'] = df['Gender'].map({'Male': 1, 'Female': 0}).astype(int)
df = pd.get_dummies(df, columns=['Geography'], drop_first=True)
X = df.drop(columns=['Exited']).values.astype(np.float32)
y = df['Exited'].values.astype(np.int64)

# Random split driven by the global NumPy RNG (seeded in the setup cell).
N = X.shape[0]
idx = np.random.permutation(N)
train_end, val_end = int(0.6 * N), int(0.8 * N)
train_idx, val_idx, test_idx = np.split(idx, [train_end, val_end])
(X_train, y_train), (X_val, y_val), (X_test, y_test) = [
    (X[rows], y[rows]) for rows in (train_idx, val_idx, test_idx)
]

# Train-split statistics only, so no information leaks from val/test.
mu = X_train.mean(axis=0)
sigma = X_train.std(axis=0) + 1e-8  # epsilon avoids division by zero on constant columns
X_train, X_val, X_test = [(part - mu) / sigma for part in (X_train, X_val, X_test)]

X_train_t, X_val_t, X_test_t = [torch.from_numpy(part) for part in (X_train, X_val, X_test)]
y_train_t, y_val_t, y_test_t = [torch.from_numpy(part) for part in (y_train, y_val, y_test)]

class ChurnDataset(Dataset):
    """Map-style dataset that pairs row ``i`` of ``X`` with entry ``i`` of ``y``.

    Length is taken from the first dimension of ``X``.
    """

    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, i):
        features = self.X[i]
        label = self.y[i]
        return features, label

# Wrap each split's tensors in a dataset and report the split sizes.
train_ds = ChurnDataset(X_train_t, y_train_t)
val_ds = ChurnDataset(X_val_t, y_val_t)
test_ds = ChurnDataset(X_test_t, y_test_t)
print(f"Train/Val/Test: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}")


In [None]:
# Model definition
class ANNNet(nn.Module):
    """Feed-forward binary classifier that returns one logit per example.

    Architecture: ``num_layers`` repeats of Linear -> ReLU (-> Dropout when
    ``dropout > 0``), followed by a single-output Linear head. The head output
    is flattened, so a ``(batch, input_dim)`` input yields a ``(batch,)`` tensor.
    """

    def __init__(self, input_dim, hidden_dim=32, num_layers=1, dropout=0.2):
        super().__init__()
        modules = []
        width = input_dim
        for _ in range(num_layers):
            modules.append(nn.Linear(width, hidden_dim))
            modules.append(nn.ReLU())
            if dropout > 0:
                modules.append(nn.Dropout(dropout))
            width = hidden_dim
        self.backbone = nn.Sequential(*modules)
        self.head = nn.Linear(width, 1)

    def forward(self, x):
        hidden = self.backbone(x)
        return self.head(hidden).view(-1)


In [None]:
# Training utilities
from math import isfinite

def loader(ds, bs=64, shuffle=True):
    """Return a DataLoader over ``ds`` with batch size ``bs`` (shuffled unless ``shuffle=False``)."""
    return DataLoader(dataset=ds, batch_size=bs, shuffle=shuffle)

def evaluate(model, dl):
    """Compute classification metrics for ``model`` over every batch in ``dl``.

    The model is put in eval mode and run without gradient tracking. Its
    outputs are treated as logits, mapped to probabilities with a sigmoid, and
    thresholded at 0.5.

    Args:
        model: module expected to return a 1-D tensor of logits per batch.
        dl: iterable of ``(features, labels)`` batches.

    Returns:
        dict with ``'accuracy'``, plus ``'roc_auc'`` when scikit-learn is
        available and the AUC is defined (both classes present in the labels).
    """
    model.eval()
    # Two separate lists; the former `logits=[], targets=[]` was a chained
    # assignment that tried to unpack [] into two targets and raised ValueError.
    probs, targets = [], []
    with torch.no_grad():
        for xb, yb in dl:
            xb = xb.to(DEVICE)
            # torch.sigmoid is numerically stable; a NumPy 1/(1+exp(-x))
            # overflows (with a RuntimeWarning) for large negative logits.
            probs.append(torch.sigmoid(model(xb)).cpu().numpy())
            targets.append(yb.cpu().numpy())
    P = np.concatenate(probs)
    T = np.concatenate(targets)
    pred = (P >= 0.5).astype(int)
    metrics = {'accuracy': float((pred == T).mean())}
    if SKLEARN_AVAILABLE:
        try:
            metrics['roc_auc'] = float(roc_auc_score(T, P))
        except ValueError:
            # roc_auc_score raises ValueError when only one class is present.
            pass
    return metrics

def train_model(hp):
    """Train an ``ANNNet`` on ``train_ds`` with early stopping on ``val_ds`` accuracy.

    Args:
        hp: hyperparameter dict. Keys read (defaults in parentheses):
            hidden_dim (32), num_layers (1), dropout (0.2), lr (1e-3),
            weight_decay (0.0), batch_size (64), epochs (30), patience (5).

    Returns:
        ``(model, hist, best_acc)``:
        - ``model`` has the weights from the epoch with the highest validation
          accuracy restored.
        - ``hist`` holds per-epoch ``'train_loss'``, ``'val_loss'`` and
          ``'val_acc'`` lists.
        - ``best_acc`` is that highest validation accuracy.
    """
    model = ANNNet(X_train.shape[1], hp.get('hidden_dim', 32), hp.get('num_layers', 1), hp.get('dropout', 0.2)).to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=hp.get('lr', 1e-3), weight_decay=hp.get('weight_decay', 0.0))
    crit = nn.BCEWithLogitsLoss()
    bs = hp.get('batch_size', 64)
    epochs = hp.get('epochs', 30)
    patience = hp.get('patience', 5)
    dl_tr, dl_va = loader(train_ds, bs), loader(val_ds, bs, shuffle=False)
    best_acc = -1
    best = None
    hist = {'train_loss': [], 'val_loss': [], 'val_acc': []}
    noimp = 0  # consecutive epochs without a validation-accuracy improvement
    for ep in range(1, epochs + 1):
        model.train()
        tot = 0.0
        for xb, yb in dl_tr:
            xb, yb = xb.to(DEVICE), yb.float().to(DEVICE)
            opt.zero_grad()
            loss = crit(model(xb), yb)
            loss.backward()
            opt.step()
            tot += loss.item() * xb.size(0)
        tr_loss = tot / len(train_ds)

        # Validation loss and accuracy in a single pass. logit >= 0 is the
        # same decision as sigmoid(logit) >= 0.5.
        model.eval()
        vt = 0.0
        correct = 0
        with torch.no_grad():
            for xb, yb in dl_va:
                xb, yb = xb.to(DEVICE), yb.float().to(DEVICE)
                logits = model(xb)
                vt += crit(logits, yb).item() * xb.size(0)
                correct += ((logits >= 0).float() == yb).sum().item()
        va_loss = vt / len(val_ds)
        va_acc = correct / len(val_ds)
        hist['train_loss'].append(tr_loss)
        hist['val_loss'].append(va_loss)
        hist['val_acc'].append(va_acc)

        if va_acc > best_acc:
            best_acc = va_acc
            # state_dict() returns references to the live parameter tensors.
            # Without cloning, later optimizer steps would overwrite this
            # snapshot and the final (not the best) weights would be restored.
            snapshot = {k: v.detach().clone() for k, v in model.state_dict().items()}
            best = {'state': snapshot, 'hp': hp, 'ep': ep}
            noimp = 0
        else:
            noimp += 1
        if noimp >= patience:
            print(f"Early stop @epoch {ep}")
            break
        if ep % 5 == 0:
            print(f"Epoch {ep}: tr_loss={tr_loss:.4f} va_loss={va_loss:.4f} va_acc={va_acc:.4f}")
    if best:
        model.load_state_dict(best['state'])
    return model, hist, best_acc


In [None]:
# Small random hyperparameter search: draw 10 configurations from `space`,
# train each with early stopping, and keep the one with the best validation accuracy.
space = {
    'lr': [1e-3, 5e-4, 1e-4],
    'hidden_dim': [16, 32, 64],
    'num_layers': [1, 2],
    'dropout': [0.0, 0.2, 0.5],
    'batch_size': [32, 64, 128],
    'weight_decay': [0.0, 1e-5, 1e-4],
}
trials = []
for _ in range(10):
    # One random.choice per key, in the dict's insertion order.
    hp = {name: random.choice(options) for name, options in space.items()}
    hp.update(epochs=30, patience=5)
    model, hist, va = train_model(hp)
    trials.append({'hp': hp, 'va': float(va), 'hist': hist})
    print('Trial:', hp, '| val_acc=', va)

best = max(trials, key=lambda trial: trial['va'])
print('Best hparams:', best['hp'])
print('Best val acc:', best['va'])


In [None]:
# Final training (train+val) and test evaluation.
# Previously `trval_ds` was built but never used: train_model always trains on
# the train split only. Now a fresh model really is trained on train+val. The
# validation rows are training data here, so they cannot drive early stopping.
# Instead we train for the number of epochs at which the best search trial
# reached its peak validation accuracy.
X_trval = np.vstack([X_train, X_val]).astype(np.float32)
y_trval = np.concatenate([y_train, y_val]).astype(np.int64)
X_trval_t = torch.from_numpy(X_trval); y_trval_t = torch.from_numpy(y_trval)
trval_ds = ChurnDataset(X_trval_t, y_trval_t)

bhp = best['hp'].copy()
# np.argmax returns the first maximum, matching the strict ">" update in train_model.
bhp['epochs'] = int(np.argmax(best['hist']['val_acc'])) + 1
bs = bhp.get('batch_size', 64)

model = ANNNet(X_trval.shape[1], bhp.get('hidden_dim', 32), bhp.get('num_layers', 1), bhp.get('dropout', 0.2)).to(DEVICE)
opt = optim.Adam(model.parameters(), lr=bhp.get('lr', 1e-3), weight_decay=bhp.get('weight_decay', 0.0))
crit = nn.BCEWithLogitsLoss()
dl_trval = DataLoader(trval_ds, batch_size=bs, shuffle=True)
for ep in range(1, bhp['epochs'] + 1):
    model.train()
    tot = 0.0
    for xb, yb in dl_trval:
        xb, yb = xb.to(DEVICE), yb.float().to(DEVICE)
        opt.zero_grad()
        loss = crit(model(xb), yb)
        loss.backward()
        opt.step()
        tot += loss.item() * xb.size(0)
print(f"Final model trained on train+val for {bhp['epochs']} epochs (last train_loss={tot / len(trval_ds):.4f})")

# Test: a single pass collects probabilities and labels for every metric below.
dl_te = DataLoader(test_ds, batch_size=bs, shuffle=False)
model.eval()
probs, targets = [], []
with torch.no_grad():
    for xb, yb in dl_te:
        probs.append(torch.sigmoid(model(xb.to(DEVICE))).cpu().numpy())
        targets.append(yb.cpu().numpy())
P = np.concatenate(probs)
T = np.concatenate(targets)
preds = (P >= 0.5).astype(int)

metrics = {'accuracy': float((preds == T).mean())}
if SKLEARN_AVAILABLE:
    try:
        metrics['roc_auc'] = float(roc_auc_score(T, P))
    except ValueError:
        pass  # AUC is undefined when the test labels contain only one class
print('Test metrics:', metrics)

if SKLEARN_AVAILABLE:
    # "\n" escapes instead of raw line breaks inside the literals (SyntaxError).
    print('\nConfusion Matrix:\n', confusion_matrix(T, preds))
    print('\nClassification Report:\n', classification_report(T, preds))


In [None]:
# Save model weights plus the preprocessing information needed to rebuild model inputs.
import json, os
ARTIFACT_DIR = 'models'
os.makedirs(ARTIFACT_DIR, exist_ok=True)
artifact = {
    # Column order of the feature matrix X. get_dummies creates the Geography_*
    # columns, so inference inputs must be built in exactly this order before
    # applying train_mean/train_std.
    'feature_names': [c for c in df.columns if c != 'Exited'],
    'train_mean': mu.tolist(),
    'train_std': sigma.tolist(),
    'best_hparams': best['hp'],
}

torch.save({'state_dict': model.state_dict(), 'artifact': artifact}, os.path.join(ARTIFACT_DIR, 'best_ann.pth'))
with open(os.path.join(ARTIFACT_DIR, 'preprocessing.json'), 'w', encoding='utf-8') as f:
    json.dump(artifact, f, indent=2)
print('Artifacts saved to', ARTIFACT_DIR)


## Hyperparameter Tuning Notes
- lr: start 1e-3 for Adam; reduce if unstable.
- hidden_dim/num_layers: increase capacity if underfitting; reduce to avoid overfitting.
- dropout: 0.2–0.5 helps regularize.
- batch_size: 64 by default; larger batches give smoother (lower-variance) gradient estimates but fewer parameter updates per epoch.
- weight_decay: 1e-5–1e-4 adds L2 regularization.
- early stopping: stop training after 5–6 consecutive epochs without improvement in validation accuracy.