<a href="https://colab.research.google.com/github/Suad0/Suad0/blob/main/jax_flax_MAML.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [22]:
!pip install jax jaxlib optax tensorflow-datasets



In [24]:
# Set random seed for reproducibility
import jax.random as random
import jax
import jax.numpy as jnp
from jax import grad, vmap, random, jit
import tensorflow_datasets as tfds
import optax
import numpy as np

rng_key = random.PRNGKey(0)

# Hyperparameters
inner_steps = 5
inner_lr = 0.1
outer_lr = 0.001
k_shot = 5  # 5-shot learning
n_way = 5   # 5-way classification
batch_tasks = 4  # Number of tasks per meta-batch
img_size = (28, 28, 1)

import tensorflow as tf

def load_omniglot():
    ds, info = tfds.load('omniglot', split='train', with_info=True, as_supervised=True)

    # Fetch all data as NumPy arrays
    images_np = []
    labels_np = []
    for image, label in tfds.as_numpy(ds):
        images_np.append(image)
        labels_np.append(label)

    images_np = jnp.array(images_np)
    labels_np = jnp.array(labels_np)

    def preprocess(image):
        image = image.astype(jnp.float32) / 255.0
        image = jnp.resize(image, img_size)
        return image

    # Apply preprocessing to the NumPy arrays
    images_processed = vmap(preprocess)(images_np)

    # Create a JAX-compatible dataset (list of tuples or similar)
    # For simplicity, let's return the processed arrays directly for now
    return images_processed, labels_np, info

def init_model_params(rng, input_shape):
    def conv_layer(rng, in_channels, out_channels):
        w_key, b_key = random.split(rng)
        # Fixed: Use proper conv weight shape (out_channels, in_channels, kernel_h, kernel_w)
        w = random.normal(w_key, (out_channels, in_channels, 3, 3)) * 0.01
        b = jnp.zeros((out_channels,))
        return {'w': w, 'b': b}

    params = []
    rng, *layer_rngs = random.split(rng, 4)
    params.append(conv_layer(layer_rngs[0], input_shape[-1], 64))
    params.append(conv_layer(layer_rngs[1], 64, 64))
    params.append({'w': random.normal(layer_rngs[2], (64, n_way)) * 0.01, 'b': jnp.zeros((n_way,))})
    return params

# Fixed forward pass
def forward(params, x):
    # x should have shape (batch_size, height, width, channels)
    # Ensure x has the right shape
    if x.ndim == 3:
        x = x[None, ...]  # Add batch dimension if missing

    # Conv layer 1
    x = jax.lax.conv_general_dilated(
        x, params[0]['w'],
        window_strides=(1, 1),
        padding='SAME',
        dimension_numbers=('NHWC', 'OIHW', 'NHWC')
    )
    x = x + params[0]['b'][None, None, None, :]  # Add bias
    x = jax.nn.relu(x)

    # Conv layer 2
    x = jax.lax.conv_general_dilated(
        x, params[1]['w'],
        window_strides=(1, 1),
        padding='SAME',
        dimension_numbers=('NHWC', 'OIHW', 'NHWC')
    )
    x = x + params[1]['b'][None, None, None, :]  # Add bias
    x = jax.nn.relu(x)

    # Global average pooling
    x = jnp.mean(x, axis=(1, 2))  # Shape: (batch_size, channels)

    # Final linear layer
    logits = jnp.dot(x, params[2]['w']) + params[2]['b']
    return logits

# Loss function
def loss_fn(params, x, y):
    logits = forward(params, x)
    return -jnp.mean(jax.nn.log_softmax(logits)[jnp.arange(y.shape[0]), y])

# Inner loop update
@jit
def inner_update(params, x_supp, y_supp):
    grads = grad(loss_fn)(params, x_supp, y_supp)
    updated_params = []
    for p, g in zip(params, grads):
        if isinstance(p, dict):
            updated_p = {}
            for key in p:
                updated_p[key] = p[key] - inner_lr * g[key]
            updated_params.append(updated_p)
        else:
            updated_params.append(p - inner_lr * g)
    return updated_params

# MAML inner loop
@jit
def maml_inner(params, x_supp, y_supp, x_query, y_query):
    adapted_params = params
    for _ in range(inner_steps):
        adapted_params = inner_update(adapted_params, x_supp, y_supp)
    return loss_fn(adapted_params, x_query, y_query)

# Meta-loss over a batch of tasks
@jit
def meta_loss(params, batch):
    losses = vmap(maml_inner, in_axes=(None, 0, 0, 0, 0))(
        params, batch['x_supp'], batch['y_supp'], batch['x_query'], batch['y_query']
    )
    return jnp.mean(losses)

# Meta-gradient computation
meta_grad = jit(grad(meta_loss))

# Training loop
def train_step(params, batch, optimizer, opt_state):
    grads = meta_grad(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

# Prepare task batch
def prepare_task_batch(images, labels, rng, num_tasks, k_shot, n_way):
    batch = {'x_supp': [], 'y_supp': [], 'x_query': [], 'y_query': []}
    num_classes = int(labels.max()) + 1  # Assuming labels are 0-indexed

    # Group images and labels by class
    class_images = [[] for _ in range(num_classes)]
    for i, label in enumerate(labels):
        class_images[int(label)].append(images[i])

    # Convert to arrays and filter out empty classes
    class_images = [jnp.array(c) for c in class_images if len(c) >= 2 * k_shot]
    available_classes = len(class_images)

    if available_classes < n_way:
        raise ValueError(f"Not enough classes with sufficient samples. Need {n_way}, got {available_classes}")

    for _ in range(num_tasks):
        rng, class_rng = jax.random.split(rng)
        # Sample n_way classes
        sampled_class_indices = jax.random.choice(
            class_rng, available_classes, shape=(n_way,), replace=False
        )

        task_x_supp = []
        task_y_supp = []
        task_x_query = []
        task_y_query = []

        for new_label, class_idx in enumerate(sampled_class_indices):
            rng, sample_rng = jax.random.split(rng)
            class_data = class_images[int(class_idx)]

            # Sample k_shot support and k_shot query examples from each class
            class_indices = jax.random.permutation(sample_rng, len(class_data))
            supp_indices = class_indices[:k_shot]
            query_indices = class_indices[k_shot:2 * k_shot]

            task_x_supp.append(class_data[supp_indices])
            task_y_supp.append(jnp.full((k_shot,), new_label))  # Use new labels 0, 1, 2, ..., n_way-1
            task_x_query.append(class_data[query_indices])
            task_y_query.append(jnp.full((k_shot,), new_label))

        batch['x_supp'].append(jnp.concatenate(task_x_supp, axis=0))
        batch['y_supp'].append(jnp.concatenate(task_y_supp, axis=0))
        batch['x_query'].append(jnp.concatenate(task_x_query, axis=0))
        batch['y_query'].append(jnp.concatenate(task_y_query, axis=0))

    return {k: jnp.array(v) for k, v in batch.items()}

# Main training
def main(rng_key):
    images_np, labels_np, _ = load_omniglot()
    params = init_model_params(rng_key, img_size)
    optimizer = optax.adam(outer_lr)
    opt_state = optimizer.init(params)

    for step in range(10):
        rng_key, subkey = random.split(rng_key)
        try:
            batch = prepare_task_batch(images_np, labels_np, subkey, batch_tasks, k_shot, n_way)
            params, opt_state = train_step(params, batch, optimizer, opt_state)
            if step % 100 == 0:
                loss = meta_loss(params, batch)
                print(f"Step {step}, Meta-Loss: {loss:.4f}")
        except ValueError as e:
            print(f"Error at step {step}: {e}")
            break

if __name__ == "__main__":
    main(rng_key)

Step 0, Meta-Loss: 1.6094
