<a href="https://colab.research.google.com/github/ShaliniAnandaPhD/PIXEL-PIONEERS-TUTORIALS/blob/main/JAX_Tutorial_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Nutrient Recommendation System:

Generate a synthetic dataset of individuals with features such as age, weight, activity level and health condition, with targets giving the daily recommended intakes of protein, vitamins and minerals
Explore Flax (a neural network library built on JAX) for OOP-style, modular neural network construction
Build a multi-task feedforward neural network that predicts the recommended nutrient intakes
Train with JAX's accelerated gradients and the Optax optimizer API
Evaluate the predictions against the true synthetic labels
Discuss enhancements such as custom neural network layers, regularization and ensembling

Generate a Synthetic Dataset:

Create a function to generate synthetic data with features like age, weight, activity levels, etc., and target daily recommended intakes for nutrients.
Ensure the data is diverse and covers a realistic range of values.

Explore JAX for OOP-Style Modular Neural Network Construction:

Define the neural network as a class that subclasses Flax's nn.Module.
Implement the forward pass in its __call__ method, together with any custom layers or functions.

Build a Multi-Task Feedforward Neural Network:

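This network predicts several outputs at once (one recommended intake per nutrient) from shared hidden layers.
Use a loss function and metrics suited to multi-task learning; the code below uses the mean squared error averaged over all three outputs.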
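One way to make the per-nutrient trade-off explicit is a weighted loss. The sketch below (multitask_mse and its task_weights parameter are hypothetical names, not part of this notebook) computes one MSE per output; with equal weights it reduces to the plain averaged MSE the notebook trains with.

```python
import jax.numpy as jnp

def multitask_mse(preds, targets, task_weights=None):
    """Weighted mean squared error over several regression outputs (hypothetical helper).

    preds, targets: arrays of shape (batch, num_tasks).
    task_weights: optional (num_tasks,) weights. The default of equal weights
    summing to 1 gives the same value as jnp.mean((preds - targets) ** 2).
    """
    per_task_mse = jnp.mean((preds - targets) ** 2, axis=0)  # one MSE per output column
    if task_weights is None:
        task_weights = jnp.ones_like(per_task_mse) / per_task_mse.shape[0]
    return jnp.sum(task_weights * per_task_mse), per_task_mse
```

Passing weights such as the inverse variance of each target column would stop the large-valued vitamin and mineral targets from dominating the protein target; whether that helps on this dataset is untested.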

Training with JAX Accelerated Gradients and Optimizer API:

Use JAX's automatic differentiation (jax.grad) to compute gradients.
Use an optimizer from Optax, a gradient-processing and optimization library for JAX, to update the model parameters.


Generate Synthetic Dataset
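This cell generates 1,000 synthetic individuals with four features (age, weight, activity level, health condition) and three targets (protein, vitamin and mineral intake). The targets are computed from age and weight only; activity level and health condition are included as features but do not affect the labels.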

In [23]:
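import numpy as np

def generate_data(num_samples=1000):
    np.random.seed(0)  # fixed seed for reproducibility
    ages = np.random.randint(18, 80, num_samples)      # 18 to 79 years
    weights = np.random.normal(70, 15, num_samples)    # mean 70, std 15
    activity_levels = np.random.choice(['sedentary', 'moderate', 'active'], num_samples)
    health_conditions = np.random.choice(['healthy', 'precondition', 'conditioned'], num_samples)

    # Synthetic targets: deterministic functions of age and weight
    protein_intake = weights * 0.8 + ages * 0.3
    vitamin_intake = 500 + (weights - 70) * 10
    mineral_intake = 200 + (ages / 80) * 600

    # np.column_stack on numbers and strings would turn every column into text,
    # so build an object array that keeps age and weight numeric.
    features = np.empty((num_samples, 4), dtype=object)
    features[:, 0] = ages
    features[:, 1] = weights
    features[:, 2] = activity_levels
    features[:, 3] = health_conditions
    targets = np.column_stack((protein_intake, vitamin_intake, mineral_intake)).astype(np.float32)
    return features, targets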

X, y = generate_data()


Preprocess Data
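This cell standardizes the numeric features (age, weight) and one-hot encodes the categorical ones (activity level, health condition), giving 8 input features. The pipeline is fitted on the full dataset, before the train/validation split.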

In [24]:
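import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler  # sparse_output needs scikit-learn >= 1.2

preprocess_pipeline = ColumnTransformer([
    ('scale', StandardScaler(), [0, 1]),                     # standardize age and weight
    ('onehot', OneHotEncoder(sparse_output=False), [2, 3]),  # one-hot encode activity level and health condition
])

def preprocess(X):
    # Fit the pipeline on X and transform it in one step; float32 suits JAX
    return preprocess_pipeline.fit_transform(X).astype(np.float32)

X_processed = preprocess(X)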


Define the Neural Network Class

In [25]:
import flax.linen as nn

class NutrientModel(nn.Module):
    """Feedforward network: shared hidden layers, one output per nutrient."""
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        x = nn.Dense(3)(x)  # 3 outputs for protein, vitamin, mineral intakes
        return x


Training Setup
This cell initializes the model parameters from a sample input and creates an Adam optimizer (learning rate 1e-3) together with its state.

In [26]:
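import jax
import optax

model = NutrientModel()
key = jax.random.PRNGKey(0)
sample_input = X_processed[:1]                    # one row fixes the input shape (8 features)
params = model.init(key, sample_input)['params']  # randomly initialized weights and biases
optimizer = optax.adam(1e-3)                      # Adam with learning rate 1e-3
state = optimizer.init(params)                    # Adam's moment estimates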


Training Step Function
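This cell defines one optimization step: compute the mean squared error on a mini-batch, differentiate it with jax.grad, and apply an Adam update with Optax.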

In [27]:
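import functools
import jax
import jax.numpy as jnp
import optax

# model and optimizer are not arrays, so mark them static; jax.jit then compiles
# the step once per batch shape and reuses the compiled code on later calls.
@functools.partial(jax.jit, static_argnames=('model', 'optimizer'))
def train_step(model, params, state, x, y, optimizer):
    def loss_fn(params):
        preds = model.apply({'params': params}, x)
        return jnp.mean((preds - y) ** 2)  # MSE over samples and nutrients

    grads = jax.grad(loss_fn)(params)
    updates, new_state = optimizer.update(grads, state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_state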


Training Loop
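This cell runs 10 epochs of mini-batch training (batch size 32) over the full preprocessed dataset. Because the train/validation split is made only afterwards, these updates also see the rows that later form the validation set, so the validation losses reported below are likely optimistic; skipping this cell avoids the leak.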

In [28]:
# 10 epochs over all rows in a fixed order, 32 rows per mini-batch
for epoch in range(10):
    for i in range(0, len(X_processed), 32):
        xi = X_processed[i:i + 32]
        yi = y[i:i + 32]
        params, state = train_step(model, params, state, xi, yi, optimizer)
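The loop above visits the rows in the same order every epoch. Reshuffling each epoch is common practice with mini-batch gradient methods; a minimal sketch using jax.random.permutation (iterate_minibatches is a hypothetical helper) follows.

```python
import numpy as np
import jax

def iterate_minibatches(key, X, y, batch_size=32):
    """Yield (x, y) mini-batches in a random order determined by `key`."""
    perm = np.asarray(jax.random.permutation(key, len(X)))  # shuffled row indices
    for start in range(0, len(X), batch_size):
        idx = perm[start:start + batch_size]
        yield X[idx], y[idx]  # the last batch may be smaller than batch_size
```

Inside the epoch loop one would split a fresh key each epoch (key, subkey = jax.random.split(key)) and iterate over iterate_minibatches(subkey, X_train, y_train) in place of the index-based slices.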


Data Splitting
This cell holds out 20% of the preprocessed rows as a validation set.

In [29]:
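from sklearn.model_selection import train_test_split

# 80/20 split with a fixed random_state for reproducibility
X_train, X_val, y_train, y_val = train_test_split(X_processed, y, test_size=0.2, random_state=42)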


Evaluation Function
This cell defines a function that returns the mean squared error of the model's predictions, averaged over all samples and all three nutrients.

In [30]:
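import jax.numpy as jnp

def evaluate_model(model, params, X, y):
    # Dense layers accept a whole batch, so jax.vmap over rows is not needed
    preds = model.apply({'params': params}, X)
    return jnp.mean((preds - y) ** 2)  # MSE over all samples and nutrients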
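A single MSE averaged over three nutrients hides which target is hardest to predict, and its squared units are hard to read. Per-nutrient root mean squared error, in the same units as each target, is one simple multi-task metric; the sketch below (per_nutrient_rmse is a hypothetical helper) computes it from a batch of predictions.

```python
import jax.numpy as jnp

def per_nutrient_rmse(preds, targets, names=('protein', 'vitamin', 'mineral')):
    """Root mean squared error for each output column, keyed by nutrient name."""
    rmse = jnp.sqrt(jnp.mean((preds - targets) ** 2, axis=0))  # shape (num_outputs,)
    return {name: float(value) for name, value in zip(names, rmse)}
```

With the notebook's objects this would be called as per_nutrient_rmse(model.apply({'params': params}, X_val), y_val).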


Model Evaluation
This cell trains for 10 further epochs on the training split only and prints the validation loss (MSE) after each epoch.

In [31]:
for epoch in range(10):
    for i in range(0, len(X_train), 32):
        xi = X_train[i:i + 32]
        yi = y_train[i:i + 32]
        params, state = train_step(model, params, state, xi, yi, optimizer)

    val_loss = evaluate_model(model, params, X_val, y_val)
    print(f"Epoch {epoch + 1}: Validation Loss: {val_loss:.4f}")


Epoch 1: Validation Loss: 7529.7271
Epoch 2: Validation Loss: 7394.5317
Epoch 3: Validation Loss: 7256.1152
Epoch 4: Validation Loss: 7113.0967
Epoch 5: Validation Loss: 6963.6729
Epoch 6: Validation Loss: 6806.1211
Epoch 7: Validation Loss: 6638.5610
Epoch 8: Validation Loss: 6459.1846
Epoch 9: Validation Loss: 6266.1938
Epoch 10: Validation Loss: 6057.4268
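The validation loss is still above 6,000 after 20 epochs in total. A likely contributor is the scale of the targets: vitamin and mineral intakes run into the hundreds, so with a learning rate of 1e-3 the network needs many steps just to reach the right output range. Standardizing the targets and inverting the scaling on the predictions is a common remedy. The sketch below uses scikit-learn's StandardScaler for this (scale_targets is a hypothetical helper); its effect on this notebook has not been measured.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

def scale_targets(y_train, y_val):
    """Standardize each nutrient column using training-set statistics only."""
    scaler = StandardScaler().fit(y_train)
    y_train_scaled = scaler.transform(y_train).astype(np.float32)
    y_val_scaled = scaler.transform(y_val).astype(np.float32)
    return scaler, y_train_scaled, y_val_scaled
```

Training would then use y_train_scaled, and predictions would be mapped back with scaler.inverse_transform(np.asarray(preds)) before comparing them with y_val in the original units.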


In [32]:
def demo_prediction(model, params, sample_features, preprocess_pipeline):
    # Wrap one individual as a 1-row object array and reuse the fitted pipeline
    sample_features_array = np.array([sample_features], dtype=object)
    sample_features_processed = preprocess_pipeline.transform(sample_features_array).astype(np.float32)

    print("Processed features shape:", sample_features_processed.shape)
    # 2 scaled numeric columns + 3 activity levels + 3 health conditions = 8 features
    if sample_features_processed.shape[1] != 8:
        raise ValueError(f"Shape mismatch after preprocessing. Expected 8 features, got {sample_features_processed.shape[1]}")

    predicted_intake = model.apply({'params': params}, sample_features_processed)
    return predicted_intake  # shape (1, 3): protein, vitamin, mineral

sample_features = [30, 70, 'active', 'healthy']  # age, weight, activity level, health condition
predicted_intake = demo_prediction(model, params, sample_features, preprocess_pipeline)
print("Predicted Nutrient Intakes:", predicted_intake)


Processed features shape: (1, 8)
Predicted Nutrient Intakes: [[ 60.026802 420.9478   457.3685  ]]
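Because the labels come from closed-form formulas in generate_data, the prediction can be checked by hand. The helper below (true_intakes is a hypothetical name) reuses those formulas for the sample individual aged 30 and weighing 70.

```python
def true_intakes(age, weight):
    """Synthetic ground-truth targets, using the same formulas as generate_data."""
    protein = weight * 0.8 + age * 0.3
    vitamin = 500 + (weight - 70) * 10
    mineral = 200 + (age / 80) * 600
    return protein, vitamin, mineral

print(true_intakes(30, 70))  # (65.0, 500.0, 425.0)
```

Against the predicted [60.0, 420.9, 457.4], the errors are roughly 5, 79 and 32, in line with the still-high validation loss after the final epoch.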
