In [1]:
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import numpy as np

In [2]:
import jax.numpy as jnp
import jax
import optax

In [3]:
from torch.utils.data import Dataset, DataLoader

In [4]:
SEED = 0
# Root PRNG key for the notebook; split it before each independent use.
key = jax.random.PRNGKey(SEED)
key

Array([0, 0], dtype=uint32)

In [5]:
from collections import namedtuple

# One mini-batch: features `X` and labels `y`. Namedtuples are JAX pytree
# containers, so a Batch can be passed straight into jitted functions.
Batch = namedtuple("Batch", "X y")

In [6]:
# Synthetic binary classification problem: 1M rows, 20 features
# (10 informative + 7 redundant + 3 repeated), 5% label noise, hard class separation.
dataset_params = {
    "n_samples": 1_000_000,
    "n_features": 20,
    "n_informative": 10,
    "n_redundant": 7,
    "n_repeated": 3,
    "n_classes": 2,
    "flip_y": 0.05,
    "class_sep": 0.5,
    "random_state": 0,
}
X, y = make_classification(**dataset_params)

In [7]:
class BinDataset(Dataset):
    """Map-style dataset over paired feature and label arrays.

    Indexing returns a ``(features_row, label)`` tuple; the length is the
    number of rows in ``X`` (its first dimension).
    """

    def __init__(self, X, y):
        self._X, self._y = X, y

    def __len__(self):
        return self._X.shape[0]

    def __getitem__(self, idx):
        features_row = self._X[idx]
        label = self._y[idx]
        return features_row, label


def collate(samples):
    """Collate ``(features, label)`` pairs from `BinDataset` into a `Batch` of JAX arrays.

    Rows are stacked on the host with NumPy and converted to a JAX array once
    per batch, instead of creating one JAX array per sample.

    Args:
        samples: non-empty sequence of ``(x, y)`` pairs, as returned by
            ``BinDataset.__getitem__``.

    Returns:
        ``Batch(X, y)``. For 1-D feature rows and scalar labels, ``X`` has shape
        ``(len(samples), n_features)`` and ``y`` has shape ``(len(samples),)``.
    """
    xs, ys = zip(*samples)
    # vstack expects a sequence of arrays (not a generator), and a single
    # host->device conversion per batch is much cheaper than one per row.
    X = jnp.asarray(np.vstack(xs))
    y = jnp.asarray(np.asarray(ys))
    return Batch(X, y)

In [8]:
# Seed the split (same seed as make_classification) so Restart & Run All
# reproduces the same train/validation partition.
X_train, X_val, y_train, y_val = train_test_split(X, y, train_size=0.8, random_state=0)
trainset = BinDataset(X_train, y_train)
valset = BinDataset(X_val, y_val)
assert len(trainset) + len(valset) == len(X), "split lost or duplicated rows"
print(len(trainset), len(valset))

800000 200000


In [9]:
def model(params, X):
    """Forward pass of a ReLU multilayer perceptron, returning raw logits.

    Args:
        params: sequence of layer dicts, each with keys ``"W"`` and ``"b"``.
            Every layer except the last is applied as ``relu(h @ W + b)``;
            the last layer is applied linearly as ``h @ W + b``.
        X: input features, multiplied on the left of each layer's ``W``.

    Returns:
        Logits (pre-sigmoid scores). For 2-D ``X`` and a 1-D final ``W``,
        this is one logit per input row.
    """
    hidden_layers, output_layer = params[:-1], params[-1]
    h = X
    for layer in hidden_layers:
        h = jax.nn.relu(h @ layer["W"] + layer["b"])
    return h @ output_layer["W"] + output_layer["b"]

In [10]:
def init_linear(key, in_features, out_features):
    """Initialise the parameters of one dense layer.

    ``W`` is drawn from a Glorot-normal distribution with shape
    ``(in_features, out_features)``. ``b`` is drawn uniformly from
    ``[-1/sqrt(in_features), 1/sqrt(in_features))``, the fan-in bound used by
    PyTorch's ``nn.Linear`` bias.

    Args:
        key: JAX PRNG key. It is split internally so ``W`` and ``b`` come from
            independent random streams.
        in_features: input width (fan-in).
        out_features: output width. When it is 1, the output axis is dropped:
            ``W`` becomes shape ``(in_features,)`` and ``b`` a scalar, so a
            forward pass yields one logit per row.

    Returns:
        ``{"W": W, "b": b}``.
    """
    w_key, b_key = jax.random.split(key)  # never reuse one key for two draws
    W = jax.nn.initializers.glorot_normal()(w_key, (in_features, out_features))
    bound = 1.0 / jnp.sqrt(in_features)
    b = jax.random.uniform(b_key, (out_features,), minval=-bound, maxval=bound)
    if out_features == 1:
        # Squeeze only the output axis so a 1-input layer keeps W as a vector.
        W = W.squeeze(axis=1)
        b = b.squeeze(axis=0)
    return {"W": W, "b": b}

In [11]:
@jax.jit
def loss(params, batch):
    """Mean sigmoid binary cross-entropy of `model` on one batch.

    Args:
        params: layer parameters accepted by `model`.
        batch: object with attributes ``X`` (features) and ``y`` (binary
            labels), e.g. a `Batch`.

    Returns:
        Scalar mean over the per-example BCE computed from logits.
    """
    per_example_bce = optax.sigmoid_binary_cross_entropy(model(params, batch.X), batch.y)
    return per_example_bce.mean()

In [12]:
# Default training hyperparameters; later cells may override learning_rate
# and n_epochs for individual runs.
hyperparams = dict(learning_rate=0.006, n_epochs=4, batch_size=32)

In [13]:
optim = optax.adam(learning_rate=hyperparams["learning_rate"])

# JAX PRNG keys must not be reused: give each layer its own subkey.
hidden_key, output_key = jax.random.split(key)
n_features = X_train.shape[1]  # input width comes from the data, not a literal
n_hidden = 8
params = [
    init_linear(hidden_key, n_features, n_hidden),
    init_linear(output_key, n_hidden, 1),
]

traindl = DataLoader(
    trainset,
    batch_size=hyperparams["batch_size"],
    shuffle=True,
    drop_last=True,
    collate_fn=collate,
)
# Validation: fixed order, large batches, keep the final partial batch.
valdl = DataLoader(
    valset, batch_size=10_000, shuffle=False, drop_last=False, collate_fn=collate
)


In [14]:
from functools import partial


@partial(jax.jit, static_argnames="optimizer")
def _jitted_step(params, opt_state, batch, optimizer):
    """Compiled body of `step`. `optimizer` is static, so each optimizer object gets its own trace."""
    loss_val, grads = jax.value_and_grad(loss)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val


def step(params, opt_state, batch, optimizer=None):
    """Run one gradient step on `batch`.

    Args:
        params: current model parameters.
        opt_state: optimizer state produced by ``optimizer.init``.
        batch: training batch passed to `loss`.
        optimizer: optax gradient transformation. Defaults to the global
            ``optim`` as it is at *call* time.

    Returns:
        ``(new_params, new_opt_state, loss_val)``, where ``loss_val`` is the
        loss at the parameters before the update.

    Why the optimizer is an explicit static argument: ``jax.jit`` bakes a
    closed-over global into the compiled trace. A cell that reassigns
    ``optim`` (as the training cell does) would otherwise keep running the old
    optimizer, including its learning rate, for already-compiled input shapes.
    """
    if optimizer is None:
        optimizer = optim
    return _jitted_step(params, opt_state, batch, optimizer)

In [18]:
# This run overrides two defaults from `hyperparams`: a 10x smaller learning
# rate and fewer epochs. (batch_size is already fixed inside `traindl`.)
run_config = {**hyperparams, "learning_rate": 0.0006, "n_epochs": 2}
log_every = 10_000  # training steps between validation passes

optim = optax.adam(learning_rate=run_config["learning_rate"])
# NOTE: training continues from whatever `params` currently holds; re-run the
# parameter-initialisation cell first to start from a fresh model.
opt_state = optim.init(params)
train_losses = []
for epoch in range(run_config["n_epochs"]):
    for i, batch in enumerate(traindl):
        params, opt_state, train_loss = step(params, opt_state, batch)
        train_losses.append(train_loss)
        if i % log_every == 0:
            # Size-weighted mean so an uneven final validation batch is not over-weighted.
            val_sum, val_count = 0.0, 0
            for valbatch in valdl:
                batch_len = valbatch.y.shape[0]
                val_sum += float(loss(params, valbatch)) * batch_len
                val_count += batch_len
            val_loss = val_sum / val_count
            # Mean training loss over the steps since the previous log line.
            window_train_loss = float(jnp.mean(jnp.array(train_losses)))
            train_losses = []
            print(f"Epoch {epoch} Step {i}: Val Loss = {val_loss:.5f}   Train Loss = {window_train_loss:.5f}")


Epoch 0 Step 0: Val Loss = 0.26663   Train Loss = 0.66109
Epoch 0 Step 10000: Val Loss = 0.25336   Train Loss = 0.25450
Epoch 0 Step 20000: Val Loss = 0.24539   Train Loss = 0.24942
Epoch 1 Step 0: Val Loss = 0.25001   Train Loss = 0.25489
Epoch 1 Step 10000: Val Loss = 0.25848   Train Loss = 0.25245
Epoch 1 Step 20000: Val Loss = 0.25031   Train Loss = 0.25255


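Log from an earlier training run with different settings. It logs step 40000 within a single epoch, which the current loader (800,000 rows / batch size 32 = 25,000 steps per epoch) cannot reach. That run's configuration was not recorded.

```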
Epoch 0 Step 0: Val Loss = 2.38983   Train Loss = 2.51969
Epoch 0 Step 10000: Val Loss = 0.35811   Train Loss = 0.39189
Epoch 0 Step 20000: Val Loss = 0.33252   Train Loss = 0.34015
Epoch 0 Step 30000: Val Loss = 0.31942   Train Loss = 0.32632
Epoch 0 Step 40000: Val Loss = 0.31470   Train Loss = 0.32119
Epoch 1 Step 0: Val Loss = 0.31439   Train Loss = 0.31875
Epoch 1 Step 10000: Val Loss = 0.31506   Train Loss = 0.31685
Epoch 1 Step 20000: Val Loss = 0.31275   Train Loss = 0.31332
Epoch 1 Step 30000: Val Loss = 0.30700   Train Loss = 0.31184
Epoch 1 Step 40000: Val Loss = 0.30697   Train Loss = 0.30810
```