# PyTorch Developer's Guide to JAX Fundamentals

This notebook, along with its accompanying blog post, helps PyTorch users get familiar with JAX/Flax by connecting the building blocks they already know from PyTorch to their equivalents in JAX/Flax!

### Data Exploration
We use the classic Titanic dataset. See
https://www.kaggle.com/c/titanic/ for details.

In [1]:
import torch

# Check whether PyTorch can see a CUDA GPU (the cell displays True or False).
torch.cuda.is_available()

True

In [2]:
import jax

# Show which backend JAX will run on by default (e.g. "cpu" or "gpu").
jax.default_backend()

'gpu'

In [3]:
import pandas as pd

TITANIC_TRAIN_CSV = "data/titanic/train.csv"  # path relative to the notebook's working directory
df = pd.read_csv(TITANIC_TRAIN_CSV)
df.head(n=5)

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


In [4]:
def preprocess_titanic_data(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.Series]:
    """Turn the raw Titanic frame into numeric features plus a ``Survived`` label series.

    Steps:
      * drop the string-valued ``Name``, ``Ticket`` and ``Cabin`` columns;
      * integer-encode ``Sex`` (male=0, female=1) and ``Embarked`` (S=0, C=1, Q=2);
      * fill missing ``Age`` values with the median age;
      * drop every row that still contains a missing value (e.g. an unmapped port).

    ``PassengerId`` is not dropped, so it remains one of the feature columns.
    NOTE(review): the ``Age`` median is computed over whatever frame is passed in;
    in this notebook that is before the train/test split, so test rows influence it.

    The input frame is left untouched; new objects are returned.

    Returns:
        ``(features, labels)`` where ``labels`` is the popped ``Survived`` column.
    """
    sex_codes = {"male": 0, "female": 1}
    port_codes = {"S": 0, "C": 1, "Q": 2}
    cleaned = (
        df.drop(columns=["Name", "Ticket", "Cabin"])
        .assign(
            Sex=lambda frame: frame["Sex"].map(sex_codes),
            Embarked=lambda frame: frame["Embarked"].map(port_codes),
            Age=lambda frame: frame["Age"].fillna(frame["Age"].median()),
        )
        .dropna()
    )
    survived = cleaned.pop("Survived")
    return cleaned, survived


# NOTE: this rebinds `df`, so re-running this cell alone fails ("Name"/"Ticket"/"Cabin"
# are already gone) — re-run the CSV-loading cell first.
df, labels = preprocess_titanic_data(df=df)
df.head(n=5)

Unnamed: 0,PassengerId,Pclass,Sex,Age,SibSp,Parch,Fare,Embarked
0,1,3,0,22.0,1,0,7.25,0.0
1,2,1,1,38.0,1,0,71.2833,1.0
2,3,3,1,26.0,0,0,7.925,0.0
3,4,1,1,35.0,1,0,53.1,0.0
4,5,3,0,35.0,0,0,8.05,0.0


In [5]:
# Split data into train and test/eval
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    df,
    labels,
    random_state=42,  # fixed seed so the split is reproducible
    test_size=0.2,  # hold out 20% of the rows for evaluation
)

As a baseline, let's train a simple `RandomForestClassifier` on this data.

In [6]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# Baseline: 100-tree random forest with a fixed seed.
rf = RandomForestClassifier(random_state=42, n_estimators=100)
rf.fit(X_train, y_train)

y_pred = rf.predict(X_test)
rf_accuracy = accuracy_score(y_test, y_pred)
print(rf_accuracy)

0.797752808988764


### Dataloader
See https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html for more details.

In [7]:
import jax.numpy as jnp
from torch.utils.data import Dataset, DataLoader, default_collate
from jax.tree_util import tree_map


class TitanicDataset(Dataset):
    """Map-style dataset yielding ``(features, label)`` tensor pairs from a DataFrame/Series.

    The frame and series are converted to tensors once, at construction. ``__getitem__``
    is then a cheap tensor index rather than a per-row pandas ``iloc`` lookup, which
    would re-materialize a Series and a fresh tensor for every sample on every epoch.

    Args:
        samples: feature rows; every column must be numeric.
        labels: per-row targets, aligned positionally (not by index label) with ``samples``.
    """

    def __init__(self, samples: pd.DataFrame, labels: pd.Series):
        # Keep the original pandas objects available for inspection.
        self.df = samples
        self.labels = labels
        # Convert via float64 first (what `.values` yields for mixed int/float rows)
        # so the float32 values match a per-row `torch.tensor(row.values, float32)`.
        self._features = torch.tensor(samples.to_numpy(dtype="float64"), dtype=torch.float32)
        self._targets = torch.tensor(labels.to_numpy(), dtype=torch.int32)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return row ``idx`` as a float32 feature vector and a 0-d int32 label tensor."""
        return self._features[idx], self._targets[idx]


# This collate function is adapted from the JAX tutorial on PyTorch data loading:
# https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html
# (this version converts leaves with jnp.asarray, so batches come out as JAX arrays)
def numpy_collate(batch):
    """Collate samples with PyTorch's ``default_collate``, then turn every leaf into a JAX array.

    Despite the name, the leaves come out as JAX arrays (via ``jnp.asarray``),
    not NumPy arrays.
    """
    collated = default_collate(batch)
    return tree_map(lambda leaf: jnp.asarray(leaf), collated)


BATCH_SIZE = 64  # shared by every loader below

train_dataset = TitanicDataset(X_train, y_train)
eval_dataset = TitanicDataset(X_test, y_test)

# JAX loaders collate into JAX arrays; the PyTorch loaders keep default tensor
# collation. Only the training loaders shuffle.
train_dataloader_jax = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=numpy_collate
)
eval_dataloader_jax = DataLoader(
    eval_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=numpy_collate
)
train_dataloader_torch = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
eval_dataloader_torch = DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)

### PyTorch Reference Implementation and Training

Model Definition

In [8]:
import torch.nn as nn
import torch.optim as optim


class TitanicNeuralNet(nn.Module):
    """Two-hidden-layer MLP that emits one logit per sample (binary classification).

    Layout: Linear -> Dropout -> LeakyReLU -> Linear -> Dropout -> LeakyReLU -> Linear (no bias).

    Args:
        num_hidden_1: width of the first hidden layer.
        num_hidden_2: width of the second hidden layer.
        num_features: number of input features per sample. Defaults to 8, the
            previously hard-coded input width.
        dropout_rate: dropout probability applied after each hidden layer.
            Defaults to 0.01, the previously hard-coded rate.
    """

    def __init__(self, num_hidden_1, num_hidden_2, num_features=8, dropout_rate=0.01):
        super().__init__()
        # Keep this creation order: each nn.Linear draws its initial weights from
        # PyTorch's global RNG, so reordering would change the initialisation.
        self.linear1 = nn.Linear(num_features, num_hidden_1)
        # One parameter-free Dropout module is reused after both hidden layers.
        self.dropout = nn.Dropout(dropout_rate)
        self.relu = nn.LeakyReLU()
        self.linear2 = nn.Linear(num_hidden_1, num_hidden_2)
        self.linear3 = nn.Linear(num_hidden_2, 1, bias=False)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, num_features)`` to logits of shape ``(batch, 1)``."""
        x = self.linear1(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.dropout(x)
        x = self.relu(x)
        out = self.linear3(x)
        return out

Training Loop

In [9]:
from tqdm import tqdm


def train(
    model: nn.Module,
    train_dataloader: DataLoader,
    eval_dataloader: DataLoader,
    num_epochs: int,
    lr: float = 0.01,
):
    """Train a binary classifier with Adam and ``BCEWithLogitsLoss``.

    After every epoch, the train and eval accuracies (from ``eval``) are shown in
    the tqdm progress bar.

    Args:
        model: network expected to return logits with a trailing size-1 axis,
            e.g. shape ``(batch, 1)``.
        train_dataloader: yields ``(features, labels)`` batches used for updates.
        eval_dataloader: yields ``(features, labels)`` batches used only for reporting.
        num_epochs: number of passes over ``train_dataloader``.
        lr: Adam learning rate (defaults to the previously hard-coded 0.01).
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.BCEWithLogitsLoss()
    for epoch in (pbar := tqdm(range(num_epochs))):
        pbar.set_description(f"Epoch {epoch}")
        # eval() below switches the model to eval mode, so re-enable training mode each epoch.
        model.train()
        for batch, labels in train_dataloader:
            # zero your gradients
            optimizer.zero_grad()

            # forward pass
            logits = model(batch)

            # compute loss; squeeze only the trailing axis so that a final batch of
            # size 1 keeps shape (1,) and still matches the (1,) labels
            # (a bare .squeeze() would yield a 0-d tensor and BCEWithLogitsLoss would reject it)
            loss = criterion(logits.squeeze(-1), labels.float())

            # backward pass
            loss.backward()

            # take an optimizer step
            optimizer.step()

        pbar.set_postfix(
            train_accuracy=eval(model, train_dataloader),
            eval_accuracy=eval(model, eval_dataloader),
        )


def eval(model: nn.Module, eval_dataloader: DataLoader):
    """Return the fraction of samples in ``eval_dataloader`` that ``model`` classifies correctly.

    A logit is turned into a 0/1 prediction by rounding its sigmoid. The model is
    left in eval mode (dropout disabled).

    Args:
        model: network returning logits with a trailing size-1 axis, e.g. ``(batch, 1)``.
        eval_dataloader: yields ``(features, labels)`` batches.

    Returns:
        Accuracy as a Python float in ``[0, 1]``. Raises ``ZeroDivisionError`` if the
        loader yields no samples.
    """
    model.eval()
    num_correct = 0
    num_samples = 0
    # Pure inference: disable autograd so no computation graph is built per batch.
    with torch.no_grad():
        for batch, labels in eval_dataloader:
            logits = model(batch)
            preds = torch.round(torch.sigmoid(logits))
            # squeeze only the trailing axis so (1, 1) -> (1,), matching the labels' shape
            num_correct += (preds.squeeze(-1) == labels).sum().item()
            num_samples += labels.shape[0]
    return num_correct / num_samples


# Seed PyTorch's global RNG so weight init, dropout masks and DataLoader shuffling
# are reproducible under Restart & Run All.
torch.manual_seed(42)
model = TitanicNeuralNet(num_hidden_1=32, num_hidden_2=16)
train(model, train_dataloader_torch, eval_dataloader_torch, num_epochs=500)

Epoch 499: 100%|██████████| 500/500 [00:26<00:00, 18.65it/s, eval_accuracy=0.798, train_accuracy=0.831]


### Flax NNX (Object-Oriented) Neural Net
See https://flax.readthedocs.io/en/latest/index.html for full NNX documentation, with more examples and a deeper dive.

Model Definition

In [10]:
from flax import nnx


class TitanicNNX(nnx.Module):
    """Flax NNX version of the PyTorch MLP: two hidden layers, one output logit per sample."""

    def __init__(self, num_hidden_1: int, num_hidden_2: int, rngs: nnx.Rngs):
        # Layers are created in a fixed order because each one draws from the shared `rngs`.
        self.linear1 = nnx.Linear(in_features=8, out_features=num_hidden_1, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.01, rngs=rngs)
        self.relu = nnx.leaky_relu
        self.linear2 = nnx.Linear(in_features=num_hidden_1, out_features=num_hidden_2, rngs=rngs)
        self.linear3 = nnx.Linear(in_features=num_hidden_2, out_features=1, use_bias=False, rngs=rngs)

    def __call__(self, x):
        # The single Dropout module is applied after each hidden Linear, before the activation.
        hidden = self.relu(self.dropout(self.linear1(x)))
        hidden = self.relu(self.dropout(self.linear2(hidden)))
        return self.linear3(hidden)

Initialize the Model

In [11]:
model = TitanicNNX(num_hidden_1=32, num_hidden_2=16, rngs=nnx.Rngs(0))
nnx.display(model)

Training Loop Setup

In [12]:
import optax
from jax import Array


def train(
    model: nnx.Module,
    train_dataloader: DataLoader,
    eval_dataloader: DataLoader,
    num_epochs: int,
    learning_rate: float = 0.01,
):
    """Train an NNX model in place with Adam, reporting accuracy after every epoch.

    Args:
        model: NNX module; ``train_step`` updates its parameters through ``nnx.Optimizer``.
        train_dataloader: yields collated ``[inputs, labels]`` batches for updates.
        eval_dataloader: yields batches used only for reporting accuracy.
        num_epochs: number of passes over ``train_dataloader``.
        learning_rate: Adam learning rate (defaults to the previously hard-coded 0.01).
    """
    optimizer = nnx.Optimizer(model, optax.adam(learning_rate=learning_rate))

    for epoch in (pbar := tqdm(range(num_epochs))):
        pbar.set_description(f"Epoch {epoch}")
        # eval() below puts the model in eval mode, so re-enable training mode each epoch.
        model.train()
        for batch in train_dataloader:
            train_step(model, optimizer, batch)

        pbar.set_postfix(
            train_accuracy=eval(model, train_dataloader),
            eval_accuracy=eval(model, eval_dataloader),
        )


@nnx.jit
def train_step(model: nnx.Module, optimizer: nnx.Optimizer, batch: list[Array]):
    """Run one jitted Adam update on ``batch`` and return the batch loss.

    Args:
        model: NNX model being trained.
        optimizer: ``nnx.Optimizer`` built around ``model``; ``optimizer.update``
            applies the gradients.
        batch: two-element ``[inputs, labels]`` sequence as produced by ``numpy_collate``.

    Returns:
        Scalar mean sigmoid binary cross-entropy for the batch (before the update).
    """
    inputs, targets = batch[0], batch[1]

    def loss_fn(model):
        logits = model(inputs)
        # Drop only the trailing size-1 logit axis: (B, 1) -> (B,).
        return optax.sigmoid_binary_cross_entropy(logits.squeeze(-1), targets).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    return loss


def eval(model: nnx.Module, eval_dataloader: DataLoader):
    """Return the fraction of samples in ``eval_dataloader`` that ``model`` predicts correctly.

    The model is switched to (and left in) eval mode. The result is a JAX scalar.
    """
    model.eval()
    seen, correct = 0, 0
    for batch in eval_dataloader:
        matches = eval_step(model, batch)
        seen += matches.shape[0]
        correct += jnp.sum(matches)
    return correct / seen


@nnx.jit
def eval_step(model: nnx.Module, batch: Array):
    """Return a boolean array marking which samples in ``batch`` are classified correctly."""
    inputs, targets = batch[0], batch[1]
    probabilities = nnx.sigmoid(model(inputs).squeeze())
    return jnp.round(probabilities) == targets

Initiate Training

In [13]:
train(
    model=model,
    train_dataloader=train_dataloader_jax,
    eval_dataloader=eval_dataloader_jax,
    num_epochs=500,
)

Epoch 499: 100%|██████████| 500/500 [00:33<00:00, 15.11it/s, eval_accuracy=0.78089887, train_accuracy=0.81434596]


### Flax Linen (Functional) Neural Net

See https://flax-linen.readthedocs.io/en/latest/index.html for full Linen documentation, with more examples and a deeper dive.

In [14]:
from flax import linen as nn

# LeCun-normal kernel init; fan-in is read from axis -2 and fan-out from axis -1 of the kernel shape.
initializer = jax.nn.initializers.lecun_normal(in_axis=-2, out_axis=-1)

# Model definition


# setup() approach
class TitanicLinen(nn.Module):
    """Flax Linen MLP declared with ``setup()``: two hidden layers, one output logit."""

    num_hidden_1: int
    num_hidden_2: int

    def setup(self):
        # setup() attribute names determine the submodule (and parameter) names; keep them.
        self.linear1 = nn.Dense(features=self.num_hidden_1, kernel_init=initializer)
        self.linear2 = nn.Dense(features=self.num_hidden_2, kernel_init=initializer)
        self.linear3 = nn.Dense(features=1, use_bias=False, kernel_init=initializer)
        self.dropout1 = nn.Dropout(rate=0.01)
        self.dropout2 = nn.Dropout(rate=0.01)

    def __call__(self, x, training):
        # Dropout is active only when `training` is truthy.
        deterministic = not training
        hidden = nn.leaky_relu(self.dropout1(self.linear1(x), deterministic=deterministic))
        hidden = nn.leaky_relu(self.dropout2(self.linear2(hidden), deterministic=deterministic))
        return self.linear3(hidden)


# @nn.compact approach
class TitanicLinenCompact(nn.Module):
    """The same Linen MLP as ``TitanicLinen``, written with ``@nn.compact`` (layers declared inline)."""

    num_hidden_1: int
    num_hidden_2: int

    @nn.compact
    def __call__(self, x, training):
        # Inline submodules are auto-named by creation order, so keep the layer order fixed.
        deterministic = not training
        hidden = nn.Dense(features=self.num_hidden_1, kernel_init=initializer)(x)
        hidden = nn.leaky_relu(nn.Dropout(rate=0.01, deterministic=deterministic)(hidden))
        hidden = nn.Dense(features=self.num_hidden_2, kernel_init=initializer)(hidden)
        hidden = nn.leaky_relu(nn.Dropout(rate=0.01, deterministic=deterministic)(hidden))
        return nn.Dense(features=1, use_bias=False, kernel_init=initializer)(hidden)

In [15]:
rng = jax.random.PRNGKey(42)
# Split into three keys: `subkey` initialises parameters, `subdropout` drives dropout;
# `new_rng` is not used by the cells shown here.
new_rng, subkey, subdropout = jax.random.split(rng, num=3)

# Initialise and smoke-test the compact model on one real batch.
sample_data, sample_labels = next(iter(train_dataloader_jax))
model = TitanicLinenCompact(num_hidden_1=32, num_hidden_2=16)
params = model.init(subkey, sample_data, training=True)
logits = model.apply(params, sample_data, training=True, rngs={"dropout": subdropout})
logits.shape

(64, 1)

In [16]:
import optax
from flax.training import train_state

# Setup for train loop
# Adam with learning rate 0.01, the same value used for the PyTorch and NNX runs.
optimizer = optax.adam(learning_rate=0.01)

# TrainState bundles apply_fn, params and optimizer state into one pytree.
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

In [17]:
from jax import jit


# Train loop
def train(state, train_dataloader, eval_dataloader, subdropout, num_epochs):
    """Run ``num_epochs`` epochs of Linen training and return the final ``TrainState``.

    Args:
        state: ``TrainState`` holding ``apply_fn``, params and optimizer state.
        train_dataloader: yields batches used for parameter updates.
        eval_dataloader: yields batches used only for reporting accuracy.
        subdropout: PRNG key forwarded to ``train_step`` for dropout.
        num_epochs: number of passes over ``train_dataloader``.

    Returns:
        The updated ``TrainState``. ``train_step`` returns a new state instead of
        mutating its input, so the caller must keep this return value.
    """
    progress = tqdm(range(num_epochs))
    for epoch in progress:
        progress.set_description(f"Epoch {epoch}")
        for batch in train_dataloader:
            state, _batch_loss = train_step(state, batch, subdropout)

        progress.set_postfix(
            train_accuracy=eval(state, train_dataloader),
            eval_accuracy=eval(state, eval_dataloader),
        )

    return state


def eval(state, eval_dataloader):
    """Return the fraction of samples in ``eval_dataloader`` predicted correctly (as a JAX scalar)."""
    seen = 0
    correct = 0
    for batch in eval_dataloader:
        matches = eval_step(state, batch)
        correct += jnp.sum(matches)
        seen += matches.shape[0]
    return correct / seen


@jit
def train_step(state, batch, subdropout):
    """Perform one jitted optimizer update for the Linen model and return ``(new_state, loss)``.

    Args:
        state: ``TrainState`` providing ``apply_fn``, ``params``, optimizer state and ``step``.
        batch: two-element ``[inputs, labels]`` sequence as produced by ``numpy_collate``.
        subdropout: base PRNG key for dropout. A per-step key is derived from it
            with ``jax.random.fold_in(subdropout, state.step)``, so callers can keep
            passing the same base key on every call.

    Returns:
        The state after applying the gradients, plus the batch's mean sigmoid
        binary cross-entropy (computed with the pre-update params).
    """
    # Using `subdropout` directly would feed the same key to every update, i.e. the
    # same dropout mask for every batch of a given shape. `state.step` advances on
    # each apply_gradients, so folding it in yields a fresh key per step.
    dropout_key = jax.random.fold_in(subdropout, state.step)

    def loss_fn(params):
        logits = state.apply_fn(params, batch[0], True, rngs={"dropout": dropout_key})
        # Drop only the trailing size-1 logit axis: (B, 1) -> (B,).
        return optax.sigmoid_binary_cross_entropy(logits.squeeze(-1), batch[1]).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss


@jit
def eval_step(state, batch):
    """Return a boolean array marking which samples in ``batch`` are classified correctly.

    The model is applied with ``training=False``, so dropout is disabled.
    """
    inputs, targets = batch[0], batch[1]
    probabilities = nn.sigmoid(state.apply_fn(state.params, inputs, False).squeeze())
    return jnp.round(probabilities) == targets

In [18]:
state = train(
    state=state,
    train_dataloader=train_dataloader_jax,
    eval_dataloader=eval_dataloader_jax,
    subdropout=subdropout,
    num_epochs=500,
)

Epoch 499: 100%|██████████| 500/500 [00:28<00:00, 17.65it/s, eval_accuracy=0.71910113, train_accuracy=0.8284107] 
