# Flax Convolutional Neural Network (CIFAR-10) Example - Interactive API

Run this Jupyter notebook in a virtual environment.

In [1]:
import os
# Make TensorFlow (used for the tfds / tf.data input pipeline) grow its GPU memory
# on demand instead of reserving it all up front.
os.environ.update(TF_FORCE_GPU_ALLOW_GROWTH='true')

In [1]:
# !pip install --upgrade -q "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# !pip install --upgrade -q git+https://github.com/google/flax.git

In [2]:
# !pip install flax==0.4.1 ml_collections optax jax==0.3.13 -q
# # jaxlib==0.3.10 -q

GPU version of JAX: pick the `jaxlib` wheel that is compatible with the pre-installed CUDA and cuDNN versions.

In [3]:
# !pip install --upgrade pip # Careful with the pip upgrade, it may cause a package dependency related problems during OpenFL workflow execution.

# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
# !pip install -U jaxlib==0.3.10+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
# !pip install -U jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5
# !pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

In [4]:
# Without either of the below flags, JAX XLA raised CUDA_OUT_OF_MEMORY exception.
# JAX XLA pre-allocates 90% of the GPU at start

# Below flag to restrict max GPU allocation to 80%
# %env XLA_PYTHON_CLIENT_MEM_FRACTION=.8

# OR

# set XLA_PYTHON_CLIENT_PREALLOCATE to false to incrementally allocate GPU memory as and when required. But can take entire GPU by the end.
# %env XLA_PYTHON_CLIENT_PREALLOCATE=false


env: XLA_PYTHON_CLIENT_MEM_FRACTION=.8


In [5]:
# %pip install tensorflow==2.8.1


In [6]:
# %pip install tensorflow_datasets ml_collections


In [2]:
# %env

{'SHELL': '/bin/bash',
 'COLORTERM': 'truecolor',
 'no_proxy': 'localhost,127.0.0.0/8,10.0.0.0/8,.intel.com',
 'TERM_PROGRAM_VERSION': '1.69.2',
 'LANGUAGE': 'en_IN:en',
 'PWD': '/home/sunilach/openfl/forked-intel-openfl/openfl-tutorials/interactive_api/Flax_CNN_CIFAR/workspace',
 'LOGNAME': 'sunilach',
 'XDG_SESSION_TYPE': 'tty',
 'ftp_proxy': 'ftp://proxy.iind.intel.com:1080/',
 'MOTD_SHOWN': 'pam',
 'HOME': '/home/sunilach',
 'LANG': 'en_US.UTF-8',
 'LS_COLORS': 'rs=0:di=01;34:ln=01;36:mh=00:pi=40;33:so=01;35:do=01;35:bd=40;33;01:cd=40;33;01:or=40;31;01:mi=00:su=37;41:sg=30;43:ca=30;41:tw=30;42:ow=34;42:st=37;44:ex=01;32:*.tar=01;31:*.tgz=01;31:*.arc=01;31:*.arj=01;31:*.taz=01;31:*.lha=01;31:*.lz4=01;31:*.lzh=01;31:*.lzma=01;31:*.tlz=01;31:*.txz=01;31:*.tzo=01;31:*.t7z=01;31:*.zip=01;31:*.z=01;31:*.dz=01;31:*.gz=01;31:*.lrz=01;31:*.lz=01;31:*.lzo=01;31:*.xz=01;31:*.zst=01;31:*.tzst=01;31:*.bz2=01;31:*.bz=01;31:*.tbz=01;31:*.tbz2=01;31:*.tz=01;31:*.deb=01;31:*.rpm=01;31:*.jar=01;31:*

In [2]:
import logging
import time

from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state
import jax
import jax.numpy as jnp
import ml_collections
import optax
import tensorflow_datasets as tfds

  PyTreeDef = type(jax.tree_structure(None))
  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# editor_relpaths = ('configs/default.py', 'train.py')

In [5]:
# @jax.jit
# def apply_model(state, images, labels):
#     """Computes gradients, loss and accuracy for a single batch."""

#     def loss_fn(params):
#         logits = state.apply_fn({'params': params}, images)
#         one_hot = jax.nn.one_hot(labels, 10)
#         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
#         return loss, logits

#     grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
#     (loss, logits), grads = grad_fn(state.params)
#     accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
#     return grads, loss, accuracy


# @jax.jit
# def update_model(state, grads):
#     return state.apply_gradients(grads=grads)

In [6]:
# def train_epoch(state, train_ds, batch_size, rng):
#     """Train for a single epoch."""
#     train_ds_size = len(train_ds['image'])
#     steps_per_epoch = train_ds_size // batch_size

#     perms = jax.random.permutation(rng, len(train_ds['image']))
#     perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
#     perms = perms.reshape((steps_per_epoch, batch_size))

#     epoch_loss = []
#     epoch_accuracy = []

#     for perm in perms:
#         batch_images = train_ds['image'][perm, ...]
#         batch_labels = train_ds['label'][perm, ...]
#         grads, loss, accuracy = apply_model(state, batch_images, batch_labels)
#         state = update_model(state, grads)
#         epoch_loss.append(loss)
#         epoch_accuracy.append(accuracy)
        
#     train_loss = np.mean(epoch_loss)
#     train_accuracy = np.mean(epoch_accuracy)
#     return state, train_loss, train_accuracy

In [7]:
# def train_and_evaluate(config: ml_collections.ConfigDict,
#                        workdir: str) -> train_state.TrainState:
#     """Execute model training and evaluation loop.
#     Args:
#         config: Hyperparameter configuration for training and evaluation.
#         workdir: Directory where the tensorboard summaries are written to.
#     Returns:
#         The train state (which includes the `.params`).
#     """
    
#     train_ds, test_ds = get_datasets()
#     rng = jax.random.PRNGKey(0)

#     summary_writer = tensorboard.SummaryWriter(workdir)
#     summary_writer.hparams(dict(config))

#     rng, init_rng = jax.random.split(rng)
#     state = create_train_state(init_rng, config)

#     for epoch in range(1, config.num_epochs + 1):
#         rng, input_rng = jax.random.split(rng)
#         state, train_loss, train_accuracy = train_epoch(state, train_ds,
#                                                         config.batch_size,
#                                                         input_rng)
#         _, test_loss, test_accuracy = apply_model(state, test_ds['image'],
#                                                   test_ds['label'])

#         logging.info(
#             'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f'
#             % (epoch, train_loss, train_accuracy * 100, test_loss,
#                test_accuracy * 100))

#         summary_writer.scalar('train_loss', train_loss, epoch)
#         summary_writer.scalar('train_accuracy', train_accuracy, epoch)
#         summary_writer.scalar('test_loss', test_loss, epoch)
#         summary_writer.scalar('test_accuracy', test_accuracy, epoch)

#     summary_writer.flush()
#     return state

In [8]:
# #Using jit decorator for GPU acceleration for entire function
# @jax.jit
# def train_step(optimizer, batch):
#     def loss_fn(model):
#         preds = model(batch['image'])
#         loss = cross_entropy_loss(preds, batch['label'])
#         return loss, preds
#     grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
#     (_, preds), grad = grad_fn(optimizer.target)
#     optimizer = optimizer.apply_gradient(grad)
#     return optimizer

# @jax.jit
# def eval_step(model, batch):
#     preds = model(batch['image'])
#     return compute_metrics(preds, batch['label'])

# def eval_model(model, test_ds):
#     metrics = eval_step(model, test_ds)
#     metrics = jax.device_get(metrics)
#     summary = jax.tree_map(lambda x: x.item(), metrics)
#     return summary['loss'], summary['accuracy']

In [9]:
# def get_datasets():
#     """Load MNIST train and test datasets into memory."""
#     ds_builder = tfds.builder('cifar10')
#     ds_builder.download_and_prepare()
#     train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
#     test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
#     train_ds['image'] = jnp.float32(train_ds['image']) / 255.
#     test_ds['image'] = jnp.float32(test_ds['image']) / 255.
#     return train_ds, test_ds

In [10]:
# def create_train_state(rng, config):
#     """Creates initial `TrainState`."""
#     cnn = CNN()
#     params = cnn.init(rng, jnp.ones([1, 28, 28, 3]))['params']
#     tx = optax.sgd(config.learning_rate, config.momentum)
#     return train_state.TrainState.create(
#       apply_fn=cnn.apply, params=params, tx=tx)

In [11]:
# class CNN(nn.Module):
#     """A simple CNN model."""
    
#     @nn.compact
#     def __call__(self, x):
#         x = nn.Conv(features=32, kernel_size=(3, 3))(x)
#         x = nn.relu(x)
#         x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
#         x = nn.Conv(features=64, kernel_size=(3, 3))(x)
#         x = nn.relu(x)
#         x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
#         x = x.reshape((x.shape[0], -1))  # flatten
#         x = nn.Dense(features=256)(x)
#         x = nn.relu(x)
#         x = nn.Dense(features=10)(x)
#         return x

In [12]:
# train_ds, test_ds = get_datasets()

2022-08-03 02:45:28.970445: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:42] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.


In [13]:
import numpy as np
import matplotlib.pyplot as plt

In [14]:
# # Helper functions for images.

# def show_img(img, ax=None, title=None):
#     """Shows a single image."""
#     if ax is None:
#         ax = plt.gca()
#         ax.imshow(img[..., 0], cmap='gray')
#         ax.set_xticks([])
#         ax.set_yticks([])
#     if title:
#         ax.set_title(title)

# def show_img_grid(imgs, titles):
#     """Shows a grid of images."""
#     n = int(np.ceil(len(imgs)**.5))
#     _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))
#     for i, (img, title) in enumerate(zip(imgs, titles)):
#         show_img(img, axs[i // n][i % n], title)

In [15]:
# show_img_grid(
#     [train_ds['image'][idx] for idx in range(25)],
#     [f'label={train_ds["label"][idx]}' for idx in range(25)],
# )

In [16]:
# from configs import default as config_lib
# config = config_lib.get_config()

In [17]:
# config.num_epochs = 3
# models = {}
# for momentum in (0.8, 0.9, 0.95):
#     name = f'momentum={momentum}'
#     config.momentum = momentum
#     state = train_and_evaluate(config, workdir=f'./models/{name}')
#     models[name] = state.params

  abs_value_flat = jax.tree_leaves(abs_value)
  value_flat = jax.tree_leaves(value)


In [18]:
# # Find all mistakes in testset.
# logits = CNN().apply({'params': state.params}, test_ds['image'])


In [19]:
# error_idxs, = jnp.where(test_ds['label'] != logits.argmax(axis=1))
# len(error_idxs) / len(logits)

0.0137

In [20]:
# show_img_grid(
#     [test_ds['image'][idx] for idx in error_idxs[:25]],
#     [f'pred={logits[idx].argmax()}' for idx in error_idxs[:25]],
# )

NameError: name 'show_img_grid' is not defined

In [None]:
# class LinearRegression:
#     def __init__(self, n_feat: int) -> None:
#         self.weights = jnp.ones(n_feat)
    
#     def mse(self, X, y) -> float:
#         return mse_loss_function1(self.weights, X, y)
 
#     def predict(self, X):
#         return jnp.dot(X, self.weights)
    
#     def fit(self, X, Y, n_epochs : int, learning_rate : int, silent : bool) -> None:
        
#         # Speed up weight updates with consecutive calls to jitted `update` function.
#         update_weights = jax.jit(update)
        
#         start_time = time.time()
#         print('Training Loss at start: ', self.mse(X, Y))
#         for i in range(n_epochs):
#             self.weights = update_weights(self.weights, X, Y, learning_rate)
#             if i % int(n_epochs/10) == 0 and not silent:
#                 print(str(i), 'Training Loss: ', self.mse(X, Y))

#         print("--- %s seconds ---" % (time.time() - start_time))

    

In [None]:
# _, init_params = CNN().init_by_shape(random.PRNGKey(0), [((1, 32, 32, 3), jnp.float32)])
# model = nn.Model(CNN, init_params)
# optimizer = optim.Adam(learning_rate=learning_rate, beta1=beta, beta2 = beta_2).create(model)

# JIT Pre-compilation

In [None]:
# train_step(optimizer, train_ds)

In [None]:
# eval_model(optimizer.target, test_ds)

In [None]:
# def train(train_ds, test_ds, model, optimizer):

#     batch_size = 128
#     num_epochs = 10
#     learning_rate = 0.001
#     beta = 0.9
#     beta_2 = 0.999
#     loss = 0
#     accuracy = 0
    
#     start_time = time.monotonic()
    
#     for epoch in range(1, num_epochs + 1):
#         train_time = 0
#         start_time_3 = time.monotonic()
#         batch_gen = tfds.as_numpy(train_ds)
#         for batch in batch_gen:
#             start_time_step = time.monotonic()
#             optimizer = train_step(optimizer, batch)
#             train_time += time.monotonic() - start_time_step
            
#         flax_step = time.monotonic() - start_time_3
        
#         start_time_2 = time.monotonic()
#         loss, accuracy = eval_model(optimizer.target, test_ds)
#         flax_inf = time.monotonic() - start_time_2
        
#         print('eval epoch: %d, epoch: %.2fs, actual_training: %.2fs, validation: %.2fs, loss: %.4f, accuracy: %.2f' % 
#               (epoch, flax_step, train_time, flax_inf, loss, accuracy * 100))
        
#     flax_time = time.monotonic() - start_time
#     return optimizer, flax_time, accuracy, flax_inf

In [None]:
# _, flax_time, flax_acc, flax_inf = train(train_ds, test_ds, model, optimizer)

## New Implementation


In [4]:
import tensorflow as tf

In [5]:
# Training pipeline: scale images to float32 in [0, 1], cast labels to int32, cache the
# decoded examples, shuffle with a 1000-example buffer, then batch.
train_ds = tfds.load('cifar10', split=tfds.Split.TRAIN)
train_ds = train_ds.map(lambda x: {'image': tf.cast(x['image'], tf.float32) / 255.,
                                     'label': tf.cast(x['label'], tf.int32)})
train_ds = train_ds.cache().shuffle(1000)
tmp_train = train_ds.batch(16)
train_ds = train_ds.batch(128)

# One small NumPy batch, used by the JIT pre-compilation / timing cells (`mini_batch`).
mini_batch = next(iter(tfds.as_numpy(tmp_train)))

# Whole test split as NumPy arrays (batch_size=-1 loads it in a single batch).
test_ds = tfds.as_numpy(tfds.load(
      'cifar10', split=tfds.Split.TEST, batch_size=-1))
test_ds = {'image': test_ds['image'].astype(jnp.float32) / 255.,
             'label': test_ds['label'].astype(jnp.int32)}

In [10]:
# Peek at one training batch. next(iter(...)) pulls only the first batch, whereas
# list(train_ds)[0] would materialise every batch of the epoch just to show one.
print(next(iter(train_ds)))

{'image': <tf.Tensor: shape=(128, 32, 32, 3), dtype=float32, numpy=
array([[[[0.3647059 , 0.41568628, 0.09803922],
         [0.4       , 0.4509804 , 0.13333334],
         [0.41568628, 0.46666667, 0.15294118],
         ...,
         [0.40784314, 0.4627451 , 0.13333334],
         [0.43137255, 0.47843137, 0.19215687],
         [0.43137255, 0.4745098 , 0.2       ]],

        [[0.37254903, 0.41960785, 0.11764706],
         [0.37254903, 0.41960785, 0.11764706],
         [0.39215687, 0.4392157 , 0.14509805],
         ...,
         [0.39607844, 0.44313726, 0.14509805],
         [0.39607844, 0.4392157 , 0.16078432],
         [0.4       , 0.44313726, 0.16078432]],

        [[0.3882353 , 0.4392157 , 0.13333334],
         [0.34117648, 0.39215687, 0.07843138],
         [0.37254903, 0.42352942, 0.11372549],
         ...,
         [0.45490196, 0.5058824 , 0.21176471],
         [0.42352942, 0.47058824, 0.18039216],
         [0.4       , 0.44705883, 0.14901961]],

        ...,

        [[0.3529412 , 0.

In [11]:
class CNN(nn.Module):
    """Small convolutional classifier that outputs log-probabilities.

    Conv(32, 3x3) -> ReLU -> AvgPool(2x2) -> Conv(64, 3x3) -> ReLU -> AvgPool(2x2)
    -> flatten -> Dense(256) -> ReLU -> Dense(num_classes) -> log_softmax.

    Linen modules define their forward pass in ``__call__``, decorated with
    ``@nn.compact`` so submodules can be declared inline. Naming the method
    ``apply`` (as the pre-Linen ``flax.nn`` API did) shadows ``nn.Module.apply``.
    """

    num_classes: int = 10

    @nn.compact
    def __call__(self, x):
        """Map an image batch ``x`` to ``(batch, num_classes)`` log-probabilities.

        NOTE(review): assumes channels-last float images, e.g. (batch, 32, 32, 3)
        in [0, 1] as produced by this notebook's CIFAR-10 pipeline.
        """
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.num_classes)(x)
        return nn.log_softmax(x)

In [12]:
def onehot(labels, num_classes=10):
    """One-hot encode integer class ``labels`` as float32.

    Args:
        labels: integer array of class indices (any shape, including a 0-d array).
        num_classes: length of the trailing one-hot axis.

    Returns:
        float32 array of shape ``labels.shape + (num_classes,)``.
    """
    # labels[..., None] adds a trailing axis that broadcasts against the class indices.
    return (labels[..., None] == jnp.arange(num_classes)).astype(jnp.float32)


def cross_entropy_loss(preds, labels):
    """Mean negative log-likelihood of integer ``labels``.

    ``preds`` must already be log-probabilities with classes on the last axis
    (e.g. the output of ``nn.log_softmax``); no softmax is applied here. The class
    count is taken from ``preds.shape[-1]`` instead of assuming 10.
    """
    # Could also be written for a single example and vectorised with jax.vmap.
    return -jnp.mean(jnp.sum(onehot(labels, preds.shape[-1]) * preds, axis=-1))


def compute_metrics(preds, labels):
    """Return ``{'loss': ..., 'accuracy': ...}`` for log-probability ``preds``."""
    return {'loss': cross_entropy_loss(preds, labels),
            'accuracy': jnp.mean(jnp.argmax(preds, -1) == labels)}

In [13]:
# jax.jit compiles the whole step (forward pass, gradient and update) into one XLA program.
@jax.jit
def train_step(optimizer, batch):
    """Apply one gradient step on ``batch`` and return the updated optimizer.

    NOTE(review): written against the legacy ``flax.optim`` interface
    (``optimizer.target`` is called as the model, ``optimizer.apply_gradient``
    returns a new optimizer). That does not pair with a ``flax.linen`` module,
    whose params are not callable — TODO migrate to ``train_state.TrainState``.
    """
    def objective(net):
        outputs = net(batch['image'])
        return cross_entropy_loss(outputs, batch['label']), outputs

    loss_and_grad = jax.value_and_grad(objective, has_aux=True)
    _, gradients = loss_and_grad(optimizer.target)
    return optimizer.apply_gradient(gradients)

@jax.jit
def eval_step(model, batch):
    """Compute ``compute_metrics`` for ``model``'s predictions on ``batch``."""
    return compute_metrics(model(batch['image']), batch['label'])

def eval_model(model, test_ds):
    """Evaluate ``model`` on all of ``test_ds`` in one jitted call.

    Returns:
        ``(loss, accuracy)`` as Python scalars (fetched to host via ``.item()``).
    """
    host_metrics = jax.device_get(eval_step(model, test_ds))
    scalars = jax.tree_map(lambda leaf: leaf.item(), host_metrics)
    return scalars['loss'], scalars['accuracy']

In [14]:
# Adam hyperparameters: step size and the two moment-decay rates.
learning_rate, beta, beta_2 = 0.001, 0.9, 0.999

In [16]:
# Build the initial model parameters and an Adam optimizer from the hyperparameters above.
# NOTE(review): this cell fails as written (see the AttributeError output):
#   * `init_by_shape`, `nn.Model` and `optim.Adam(...).create(model)` come from the
#     pre-Linen `flax.nn` / `flax.optim` API, but `nn` here is `flax.linen`.
#   * `random` and `optim` are not imported by any active cell.
# TODO: migrate to Linen + optax (requires CNN to define an `@nn.compact __call__`), e.g.
#   params = CNN().init(jax.random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))['params']
#   state = train_state.TrainState.create(apply_fn=CNN().apply, params=params,
#                                         tx=optax.adam(learning_rate, b1=beta, b2=beta_2))
# and change train_step / eval_model / train to take `state` instead of `optimizer`.
_, init_params = CNN.init_by_shape(random.PRNGKey(0), [((1, 32, 32, 3), jnp.float32)])
model = nn.Model(CNN, init_params)
optimizer = optim.Adam(learning_rate=learning_rate, beta1=beta, beta2 = beta_2).create(model)

AttributeError: "CNN" object has no attribute "init_by_shape"

In [None]:
# JIT pre-compilation: the first call of each jitted function traces and compiles it,
# so these timings include compilation cost.
# JAX dispatches work asynchronously, so block on train_step's result; otherwise the
# timer only measures dispatch. eval_model already synchronises via `.item()`.
start_mini = time.monotonic()
jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), train_step(optimizer, mini_batch))
mini_time = time.monotonic() - start_mini
print('mini_batch training: %.2fs' % mini_time)

start_mini_2 = time.monotonic()
eval_model(optimizer.target, test_ds)
mini_val_time = time.monotonic() - start_mini_2
print('mini_batch validation: %.2fs' % mini_val_time)

In [None]:
# Mini-batch after compilation: the same calls now reuse the cached compiled programs.
# Block on train_step's result so the timer covers the computation, not just the
# asynchronous dispatch; eval_model already synchronises via `.item()`.
start_mini = time.monotonic()
jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), train_step(optimizer, mini_batch))
mini_time = time.monotonic() - start_mini
print('mini_batch training: %.2fs' % mini_time)

start_mini_2 = time.monotonic()
eval_model(optimizer.target, test_ds)
mini_val_time = time.monotonic() - start_mini_2
print('mini_batch validation: %.2fs' % mini_val_time)

In [None]:
def train(train_ds, test_ds, model, optimizer, num_epochs=10):
    """Train ``optimizer.target`` on ``train_ds`` and evaluate on ``test_ds`` after each epoch.

    Args:
        train_ds: batched dataset, iterated once per epoch through ``tfds.as_numpy``.
        test_ds: evaluation data passed whole to ``eval_model`` (in this notebook a
            dict of 'image'/'label' arrays).
        model: unused; kept so existing calls keep working.
        optimizer: optimizer consumed by ``train_step``; its ``.target`` is evaluated.
        num_epochs: number of passes over ``train_ds`` (default 10).

    Returns:
        ``(optimizer, total_seconds, last_test_accuracy, last_validation_seconds)``.
    """
    def wait(tree):
        # JAX dispatches asynchronously; block so step timings include the real compute.
        jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), tree)

    accuracy = 0
    flax_inf = 0.0  # defined even when num_epochs == 0

    start_time = time.monotonic()

    for epoch in range(1, num_epochs + 1):
        train_time = 0.0
        epoch_start = time.monotonic()
        for batch in tfds.as_numpy(train_ds):
            step_start = time.monotonic()
            optimizer = train_step(optimizer, batch)
            wait(optimizer)
            train_time += time.monotonic() - step_start
        flax_step = time.monotonic() - epoch_start

        eval_start = time.monotonic()
        loss, accuracy = eval_model(optimizer.target, test_ds)
        flax_inf = time.monotonic() - eval_start

        print('eval epoch: %d, epoch: %.2fs, actual_training: %.2fs, validation: %.2fs, loss: %.4f, accuracy: %.2f' %
              (epoch, flax_step, train_time, flax_inf, loss, accuracy * 100))

    flax_time = time.monotonic() - start_time
    return optimizer, flax_time, accuracy, flax_inf

In [None]:
# Full training run; keep the timing and accuracy results (the optimizer goes to `_`).
(_, flax_time, flax_acc, flax_inf) = train(
    train_ds, test_ds, model, optimizer)


# JAX Linear Regression with federation

## Connect to a Federation

In [None]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# The client id must match the identifier used in the signed certificate.
client_id, director_node_fqdn, director_port = 'frontend', 'localhost', 50050

# NOTE(review): TLS is disabled; acceptable for a local demo only.
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False,
)

In [None]:
# Query the director for this federation's shard registry and display it.
shard_registry = federation.get_shard_registry()
shard_registry

### Initialize Data Interface

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

class LinearRegressionDataSet(DataInterface):
    """Expose a shard descriptor's 'train' and 'val' datasets to OpenFL tasks."""

    def __init__(self, **kwargs):
        """Store the constructor keyword arguments on ``self.kwargs``."""
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        """The shard descriptor last assigned through the setter."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """Attach ``shard_descriptor`` and load its 'train' and 'val' datasets.

        Called during collaborator initialization; the local shard descriptor is
        set by the Envoy.
        """
        self._shard_descriptor = shard_descriptor
        self.train_set = self._shard_descriptor.get_dataset('train')
        self.val_set = self._shard_descriptor.get_dataset('val')

    def get_train_loader(self, **kwargs):
        """Data given to tasks whose contract includes an optimizer."""
        return self.train_set

    def get_valid_loader(self, **kwargs):
        """Data given to tasks whose contract has no optimizer."""
        return self.val_set

    def get_train_data_size(self):
        """Number of training samples, used by the aggregator."""
        return len(self.train_set)

    def get_valid_data_size(self):
        """Number of validation samples, used by the aggregator."""
        return len(self.val_set)

lin_reg_dataset = LinearRegressionDataSet()

### Define Model Interface

In [None]:
import time


# LinearRegression, update and mse_loss_function1 are otherwise only present in
# commented-out cells, so they are defined here before the model interface uses them.
def mse_loss_function1(weights, X, y):
    """Mean squared error of the linear prediction ``X @ weights`` against ``y``."""
    return jnp.mean(jnp.square(jnp.dot(X, weights) - y))


def update(weights, X, y, learning_rate):
    """One gradient-descent step on ``mse_loss_function1`` with respect to ``weights``."""
    return weights - learning_rate * jax.grad(mse_loss_function1)(weights, X, y)


class LinearRegression:
    """Bias-free linear regression trained by full-batch gradient descent.

    Args:
        n_feat: number of input features (one weight per feature, initialised to 1).
    """

    def __init__(self, n_feat: int) -> None:
        self.weights = jnp.ones(n_feat)

    def mse(self, X, y) -> float:
        """Mean squared error of the current weights on ``(X, y)``."""
        return mse_loss_function1(self.weights, X, y)

    def predict(self, X):
        """Linear prediction ``X @ weights``."""
        return jnp.dot(X, self.weights)

    def fit(self, X, Y, n_epochs: int, learning_rate: float, silent: bool) -> None:
        """Run ``n_epochs`` gradient-descent steps, logging the loss ~10 times unless ``silent``."""
        # Speed up weight updates with consecutive calls to the jitted `update` function.
        update_weights = jax.jit(update)
        # At least 1, so fewer than 10 epochs does not cause a modulo-by-zero.
        log_every = max(1, n_epochs // 10)

        start_time = time.time()
        print('Training Loss at start: ', self.mse(X, Y))
        for i in range(n_epochs):
            self.weights = update_weights(self.weights, X, Y, learning_rate)
            if i % log_every == 0 and not silent:
                print(str(i), 'Training Loss: ', self.mse(X, Y))

        print("--- %s seconds ---" % (time.time() - start_time))


framework_adapter = 'custom_adapter.CustomFrameworkAdapter'

# LinearRegression class accepts a parameter n_features. Should be same as `sample_shape` from `director_config.yaml`
fed_model = LinearRegression(1)
MI = ModelInterface(model=fed_model, optimizer=None, framework_plugin=framework_adapter)

# Save the initial model state
initial_model = LinearRegression(1)

### Register Tasks
We need to employ a trick when reporting metrics: OpenFL decides which model is the best based on an *increasing* metric.

In [None]:
TI = TaskInterface()

# Training task: `lr` and `epochs` are bound through add_kwargs. `optimizer` and
# `device` are part of the registered contract but are not used by this model.
@TI.add_kwargs(lr=0.01, epochs=101)
@TI.register_fl_task(model='my_model', data_loader='train_data',
                     device='device', optimizer='optimizer')
def train(my_model, train_data, optimizer, device, lr, epochs):
    """Fit ``my_model`` on the training shard and report its training MSE."""
    # Last column is the target; all preceding columns are features.
    features, targets = train_data[:, :-1], train_data[:, -1]
    my_model.fit(features, targets, epochs, lr, silent=False)
    return {'train_MSE': my_model.mse(features, targets)}


@TI.register_fl_task(model='my_model', data_loader='val_data', device='device')
def validate(my_model, val_data, device):
    """Report ``my_model``'s MSE on the validation shard."""
    features, targets = val_data[:, :-1], val_data[:, -1]
    return {'validation_MSE': my_model.mse(features, targets)}

### Run the federation

In [None]:
# Create an FL experiment named `experiment_name` on the federation connected above.
experiment_name = 'jax_linear_regression_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
# Start the experiment with the model interface (MI), task interface (TI) and data
# interface defined above, training for 2 rounds.
fl_experiment.start(model_provider=MI,
                    task_keeper=TI,
                    data_loader=lin_reg_dataset,
                    rounds_to_train=2)

In [None]:
# Stream the metrics reported by the running experiment.
fl_experiment.stream_metrics()

# JAX Linear Regression without federation (Optional Simulation)

In [None]:
!pip install matplotlib scikit-learn -q

In [None]:
# Imports for running JAX Linear Regression example without OpenFL.

import matplotlib.pyplot as plt

%matplotlib inline
from matplotlib.pylab import rcParams
rcParams['figure.figsize'] = 7, 5

from jax import make_jaxpr
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

#### Simple Linear Regression
<img src="https://www.analyticsvidhya.com/wp-content/uploads/2016/01/eq5-1.png" width="500">



In [None]:
# create a dataset with n_features
X, y = make_regression(n_samples=100, n_features=1, noise=10, random_state=42)

# Train test split - Default 0.75/0.25
X, X_test, y, y_test = train_test_split(X, y, random_state=42)

Visualize the data distribution.

In [None]:
# Training split: single feature vs. target; trailing ';' suppresses the repr output.
fig_train, ax_train = plt.subplots()
ax_train.scatter(X, y)
ax_train.set(title='Training data', xlabel='X', ylabel='y');

In [None]:
# Test split: single feature vs. target; trailing ';' suppresses the repr output.
fig_test, ax_test = plt.subplots()
ax_test.scatter(X_test, y_test)
ax_test.set(title='Test data', xlabel='X_test', ylabel='y_test');

In [None]:
# JAX logical execution plan: print the jaxpr traced for one `update` call with
# all-ones initial weights and a learning rate of 0.01.
print(make_jaxpr(update)(jnp.ones(X.shape[1]), X, y, 0.01))

In [None]:
# Train a fresh model and compare test MSE before and after fitting.
# X.shape is (n_samples, n_features), so X.shape[1] is the number of weights.
lr_model = LinearRegression(X.shape[1])
lr, epochs = 0.01, 101

print(f"Initial Test MSE: {lr_model.mse(X_test, y_test)}")

# silent controls logging verbosity during fit.
lr_model.fit(X, y, epochs, lr, silent=False)

print(f"Final Test MSE: {lr_model.mse(X_test, y_test)}")
print(f"Final parameters: {lr_model.weights}")