References: https://cloud.google.com/blog/products/ai-machine-learning/guide-to-jax-for-pytorch-developers

In [106]:
import torch
from torch.utils.data import Dataset, DataLoader, default_collate
import torch.nn as nn
import torch.optim as optim
import jax.numpy as jnp
from jax.tree_util import tree_map
from flax import nnx
from tqdm import tqdm
import optax
import pandas as pd
from sklearn.model_selection import train_test_split

In [107]:
# Dataset Definition
class TitanicDataset(Dataset):
    """Map-style dataset that serves (features, label) pairs as float32 tensors.

    The pandas inputs are converted to tensors once at construction time, so
    ``__getitem__`` is a cheap tensor index instead of a per-sample pandas
    ``.iloc`` lookup followed by a fresh tensor allocation.

    Args:
        samples: pandas DataFrame with one row per sample; all columns must be
            convertible to float32.
        labels: pandas Series with one label per row of ``samples``, addressed
            by position (not by index label).
    """

    def __init__(self, samples, labels):
        # Keep the original pandas objects available under their old names.
        self.df = samples
        self.labels = labels
        # One-time conversion; later changes to ``samples``/``labels`` are not
        # reflected in the served tensors.
        self._features = torch.tensor(
            samples.to_numpy(dtype="float32"), dtype=torch.float32
        )
        self._targets = torch.tensor(labels.to_numpy(), dtype=torch.float32)

    def __len__(self):
        """Return the number of samples (rows of the feature frame)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the ``idx``-th sample as ``(features, label)``.

        ``features`` is a 1-D float32 tensor and ``label`` a 0-D float32 tensor.
        """
        return self._features[idx], self._targets[idx]


def numpy_collate(batch):
    """Collate a list of samples and convert every leaf to a JAX array.

    The samples are first stacked with PyTorch's ``default_collate``; each
    resulting tensor in the collated structure is then passed through
    ``jnp.asarray``. Despite the function name, the leaves are JAX arrays.
    """
    collated = default_collate(batch)
    return tree_map(jnp.asarray, collated)

In [108]:
# Model definition in PyTorch
class TitanicNeuralNet(nn.Module):
    """Three-layer fully connected network producing one logit per sample.

    Layout: Linear(8 -> num_hidden_1) -> Dropout -> LeakyReLU
            -> Linear(num_hidden_1 -> num_hidden_2) -> Dropout -> LeakyReLU
            -> Linear(num_hidden_2 -> 1, no bias).

    The same dropout and activation modules are reused after both hidden
    layers.
    """

    def __init__(self, num_hidden_1, num_hidden_2):
        super().__init__()
        self.linear1 = nn.Linear(in_features=8, out_features=num_hidden_1)
        self.dropout = nn.Dropout(p=0.01)
        self.relu = nn.LeakyReLU()
        self.linear2 = nn.Linear(in_features=num_hidden_1, out_features=num_hidden_2)
        self.linear3 = nn.Linear(in_features=num_hidden_2, out_features=1, bias=False)

    def forward(self, x):
        """Return raw (pre-sigmoid) logits for the input batch ``x``."""
        hidden = self.relu(self.dropout(self.linear1(x)))
        hidden = self.relu(self.dropout(self.linear2(hidden)))
        return self.linear3(hidden)

In [109]:
# Model definition in Flax
class TitanicNNX(nnx.Module):
    """Flax NNX network with two hidden layers and a single-logit output.

    Layout: Linear(8 -> num_hidden_1) -> Dropout -> leaky ReLU
            -> Linear(num_hidden_1 -> num_hidden_2) -> Dropout -> leaky ReLU
            -> Linear(num_hidden_2 -> 1, no bias).

    ``rngs`` is passed to every layer that needs random numbers (the linear
    layers and the dropout layer).
    """

    def __init__(self, num_hidden_1, num_hidden_2, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(in_features=8, out_features=num_hidden_1, rngs=rngs)
        self.dropout = nnx.Dropout(rate=0.01, rngs=rngs)
        # Stored as a plain function reference; it has no parameters.
        self.relu = nnx.leaky_relu
        self.linear2 = nnx.Linear(
            in_features=num_hidden_1, out_features=num_hidden_2, rngs=rngs
        )
        self.linear3 = nnx.Linear(
            in_features=num_hidden_2, out_features=1, use_bias=False, rngs=rngs
        )

    def __call__(self, x):
        """Return raw (pre-sigmoid) logits for the input batch ``x``."""
        hidden = self.relu(self.dropout(self.linear1(x)))
        hidden = self.relu(self.dropout(self.linear2(hidden)))
        return self.linear3(hidden)

In [110]:
# PyTorch training loop
def train_torch(model, train_dataloader, eval_dataloader, num_epochs, lr=0.01):
    """Train a PyTorch binary classifier with Adam and BCE-with-logits loss.

    After every epoch the accuracy on both dataloaders is shown in the tqdm
    progress bar. The model is updated in place; nothing is returned.

    Args:
        model: module mapping a feature batch to logits.
        train_dataloader: iterable of ``(features, labels)`` training batches.
        eval_dataloader: iterable of ``(features, labels)`` held-out batches.
        num_epochs: number of passes over ``train_dataloader``.
        lr: Adam learning rate (defaults to 0.01).
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.BCEWithLogitsLoss()
    for epoch in (pbar := tqdm(range(num_epochs))):
        pbar.set_description(f"Epoch {epoch}")
        # eval_torch() below puts the model in eval mode, so training mode has
        # to be re-enabled at the start of every epoch.
        model.train()
        for batch, labels in train_dataloader:
            optimizer.zero_grad()
            logits = model(batch)
            # Squeeze only the trailing axis (assumes (batch, 1) logits). A
            # bare squeeze() would also drop the batch axis when the last
            # batch holds a single sample, making the loss raise on a shape
            # mismatch with ``labels``.
            loss = criterion(logits.squeeze(-1), labels)
            loss.backward()
            optimizer.step()

        pbar.set_postfix(
            train_accuracy=eval_torch(model, train_dataloader),
            eval_accuracy=eval_torch(model, eval_dataloader),
        )


def eval_torch(model, eval_dataloader):
    """Return the classification accuracy of ``model`` over a dataloader.

    The model is switched to eval mode (and left there). Predictions are the
    sigmoid of the logits rounded to 0/1 and compared with the labels.

    Args:
        model: module mapping a feature batch to logits.
        eval_dataloader: iterable of ``(features, labels)`` batches.

    Returns:
        Fraction of correctly classified samples as a Python float.

    Raises:
        ZeroDivisionError: if the dataloader yields no samples.
    """
    model.eval()
    num_correct = 0
    num_samples = 0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every forward pass.
    with torch.no_grad():
        for batch, labels in eval_dataloader:
            logits = model(batch)
            # Squeeze only the trailing axis so a single-sample batch keeps
            # its batch dimension (assumes (batch, 1) logits).
            preds = torch.round(torch.sigmoid(logits)).squeeze(-1)
            num_correct += (preds == labels).sum().item()
            num_samples += labels.shape[0]
    return num_correct / num_samples

In [111]:
# NNX training loop
def train_nnx(model, train_dataloader, eval_dataloader, num_epochs, lr=0.01):
    """Train a Flax NNX binary classifier with the optax Adam optimizer.

    After every epoch the accuracy on both dataloaders is shown in the tqdm
    progress bar. The model is updated in place; nothing is returned.

    Args:
        model: NNX module mapping a feature batch to logits.
        train_dataloader: iterable of ``[features, labels]`` training batches.
        eval_dataloader: iterable of ``[features, labels]`` held-out batches.
        num_epochs: number of passes over ``train_dataloader``.
        lr: Adam learning rate (defaults to 0.01).
    """
    optimizer = nnx.ModelAndOptimizer(model, optax.adam(learning_rate=lr))
    for epoch in (pbar := tqdm(range(num_epochs))):
        pbar.set_description(f"Epoch {epoch}")
        # eval_nnx() below puts the model in eval mode, so training mode has
        # to be re-enabled at the start of every epoch.
        model.train()
        for batch in train_dataloader:
            train_step(model, optimizer, batch)
        pbar.set_postfix(
            train_accuracy=eval_nnx(model, train_dataloader),
            eval_accuracy=eval_nnx(model, eval_dataloader),
        )


@nnx.jit
def train_step(model, optimizer, batch):
    """Run one jitted optimization step and return the batch loss.

    Args:
        model: NNX module mapping a feature batch to logits.
        optimizer: NNX optimizer wrapping ``model``; updated in place.
        batch: indexable pair where ``batch[0]`` holds the features and
            ``batch[1]`` the binary labels.

    Returns:
        The mean sigmoid binary cross-entropy of the batch, computed before
        the parameter update. Callers that only need the side effect may
        ignore it.
    """
    features, targets = batch[0], batch[1]

    def loss_fn(model):
        logits = model(features)
        return optax.sigmoid_binary_cross_entropy(logits.squeeze(), targets).mean()

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    return loss


def eval_nnx(model, eval_dataloader):
    """Return the classification accuracy of an NNX model over a dataloader.

    The model is switched to eval mode (and left there). Each batch is scored
    by ``eval_step``, which yields one boolean per sample.

    Args:
        model: NNX module mapping a feature batch to logits.
        eval_dataloader: iterable of ``[features, labels]`` batches.

    Returns:
        Fraction of correctly classified samples as a Python float, matching
        the return type of ``eval_torch``.

    Raises:
        ZeroDivisionError: if the dataloader yields no samples.
    """
    model.eval()
    total = 0
    num_correct = 0
    for batch in eval_dataloader:
        res = eval_step(model, batch)
        total += res.shape[0]
        # Convert to a Python int so the final ratio is an ordinary float
        # rather than a 0-d float32 JAX array.
        num_correct += int(jnp.sum(res))
    return num_correct / total


@nnx.jit
def eval_step(model, batch):
    """Score one batch and return an element-wise correctness mask.

    ``batch[0]`` holds the features and ``batch[1]`` the labels. The logits
    are squeezed, passed through a sigmoid and rounded to obtain the
    predictions, which are then compared with the labels.
    """
    features, targets = batch[0], batch[1]
    probabilities = nnx.sigmoid(model(features).squeeze())
    return jnp.round(probabilities) == targets

In [112]:
# Read the dataset
# The relative path is resolved against the current working directory, so
# "titanic_dataset.csv" must be present there.
df = pd.read_csv("titanic_dataset.csv")

In [113]:
# Preview the first rows of the raw frame before any preprocessing
df.head()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


In [114]:
# Preprocessing
# The steps run in order as one chain:
#   1. drop the columns that are not used as features
#   2. encode Sex as binary and Embarked as ordinal integers
#   3. fill missing Age with the median and missing Embarked with the mode
#   4. add FamilySize = SibSp + Parch
df = (
    df.drop(columns=["PassengerId", "Name", "Ticket", "Cabin"])
    # Encode the categorical columns
    .assign(
        Sex=lambda frame: frame["Sex"].map({"female": 0, "male": 1}),
        Embarked=lambda frame: frame["Embarked"].map({"C": 0, "Q": 1, "S": 2}),
    )
    # Treat missing values, then do the feature engineering
    .assign(
        Age=lambda frame: frame["Age"].fillna(frame["Age"].median()),
        Embarked=lambda frame: frame["Embarked"].fillna(frame["Embarked"].mode()[0]),
        FamilySize=lambda frame: frame["SibSp"] + frame["Parch"],
    )
)

In [115]:
# Train / Test split
# Fix the seed so the stratified 80/20 split, and therefore the reported
# accuracies, is the same on every Restart & Run All.
SPLIT_SEED = 42
y = df["Survived"]
X = df.drop("Survived", axis=1)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=SPLIT_SEED
)

In [116]:
# Seed PyTorch's global RNG so the PyTorch weight initialization and the
# shuffling order of the DataLoaders are reproducible. The Flax model below
# is already seeded through nnx.Rngs(0).
torch.manual_seed(0)

# Create Train dataset, dataloader from a pandas dataframe
# The same dataset backs both loaders; the JAX loader only differs in its
# collate function, which converts each batch to JAX arrays.
train_dataset = TitanicDataset(X_train, y_train)
train_dataloader_jax = DataLoader(
    train_dataset, batch_size=64, shuffle=True, collate_fn=numpy_collate
)
train_dataloader_torch = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Create Eval dataset, dataloader from a pandas dataframe
eval_dataset = TitanicDataset(X_test, y_test)
eval_dataloader_jax = DataLoader(
    eval_dataset, batch_size=64, shuffle=False, collate_fn=numpy_collate
)
# NOTE: the name is misspelled ("dataloder") but is kept as-is because the
# training cell refers to it by this name.
eval_dataloder_torch = DataLoader(eval_dataset, batch_size=64, shuffle=False)

# PyTorch Initialization
torch_model = TitanicNeuralNet(num_hidden_1=32, num_hidden_2=16)

# Flax NNX Initialization
flax_model = TitanicNNX(num_hidden_1=32, num_hidden_2=16, rngs=nnx.Rngs(0))

In [122]:
# Train the PyTorch model
# NOTE: both training functions update the model objects created in an earlier
# cell in place (each call builds a fresh optimizer), so re-running this cell
# continues training from the current weights rather than from a fresh
# initialization. Re-run the model-creation cell first to start over.
tqdm.write("Training PyTorch Model")
train_torch(
    model=torch_model,
    train_dataloader=train_dataloader_torch,
    eval_dataloader=eval_dataloder_torch,
    num_epochs=10,
)

# Train the Flax model
# Uses the JAX dataloaders, whose batches are collated into JAX arrays.
tqdm.write("Training Flax Model")
train_nnx(
    model=flax_model,
    train_dataloader=train_dataloader_jax,
    eval_dataloader=eval_dataloader_jax,
    num_epochs=10,
)

Training PyTorch Model


Epoch 9: 100%|██████████| 10/10 [00:01<00:00,  8.36it/s, eval_accuracy=0.821, train_accuracy=0.823]


Training Flax Model


Epoch 9: 100%|██████████| 10/10 [00:01<00:00,  5.24it/s, eval_accuracy=0.8324022, train_accuracy=0.81179774]
