Question-3

jax.numpy: Implements the NumPy API, using the primitives in jax.lax.
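A minimal sketch of that layering, assuming a standard JAX install. jnp calls read like their NumPy counterparts. One well-known difference is that JAX arrays are immutable, so updates go through `.at[...]`. jax.lax is the lower-level primitive layer and is stricter, for example about matching dtypes. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# jax.numpy mirrors the NumPy interface: the same call works in both libraries.
x_np = np.linspace(0.0, 1.0, 5)
x = jnp.linspace(0.0, 1.0, 5)
y_np = np.sin(x_np) * 2.0
y = jnp.sin(x) * 2.0

# JAX arrays are immutable: instead of x[0] = 10.0, use the functional
# .at[...].set(...) update, which returns a new array.
x_updated = x.at[0].set(10.0)

# jax.lax is the primitive layer jnp is built on; lax.add expects operands
# of matching dtypes, so both are given explicitly as float32 here.
z = lax.add(jnp.float32(1.0), jnp.float32(2.0))

print(np.allclose(y, y_np, atol=1e-6))  # True
print(x_updated[0], x[0])               # the original x is unchanged
print(z)
```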

jax.scipy.special.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False)
Compute the log of the sum of exponentials of input elements.
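Why use logsumexp instead of writing `log(sum(exp(a)))` directly? The naive form overflows for large inputs. logsumexp avoids this with the standard max-shift trick, log Σ exp(a) = m + log Σ exp(a − m) with m = max(a). A short sketch of that behaviour, of the `b` weights (which give log Σ b·exp(a)), and of `axis`/`keepdims`. The inputs are illustrative.

```python
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([1000.0, 1000.0])

# Naive formula: exp(1000) overflows float32 to inf, so the result is inf.
naive = jnp.log(jnp.sum(jnp.exp(a)))

# logsumexp stays finite: 1000 + log(2), about 1000.693.
stable = logsumexp(a)

# Weights b compute log(sum(b * exp(a))); here log(2*1 + 3*1) = log(5).
weighted = logsumexp(jnp.array([0.0, 0.0]), b=jnp.array([2.0, 3.0]))

# Row-wise reduction with keepdims=True, e.g. to normalise logits (log-softmax).
logits = jnp.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
log_probs = logits - logsumexp(logits, axis=1, keepdims=True)

print(naive, stable, weighted)
```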

JAX runs transparently on the GPU or TPU (falling back to CPU if you don’t have one).
However, when code runs op by op, JAX dispatches kernels to the GPU one operation at a time.
If we have a sequence of operations, we can use the @jit decorator to compile them together with XLA.
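A sketch of compiling a chain of operations with jit, assuming a standard JAX install. The `selu` function here is an illustrative example. Timings depend entirely on hardware, so none are claimed. `block_until_ready()` is used because JAX dispatches work asynchronously. Without it, the timer would stop before the computation had finished.

```python
import time
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # Several elementwise ops in sequence: compare, exp, multiply, where.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)  # same effect as decorating selu with @jax.jit

x = jnp.linspace(-5.0, 5.0, 1_000_000)

# The first call traces and compiles with XLA; later calls reuse the result.
selu_jit(x).block_until_ready()

start = time.perf_counter()
selu(x).block_until_ready()        # op-by-op dispatch
eager_s = time.perf_counter() - start

start = time.perf_counter()
selu_jit(x).block_until_ready()    # one compiled computation
jit_s = time.perf_counter() - start

print("backend:", jax.default_backend())
print(f"eager {eager_s:.5f}s, jit {jit_s:.5f}s")
```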

Taking derivatives with grad()
In addition to evaluating numerical functions, we also want to transform them. 
One transformation is automatic differentiation. 
In JAX, just like in Autograd, you can compute gradients with the grad() function.

In [None]:
import jax.numpy as jnp
from jax import grad

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

Output:
[0.25 0.19661197 0.10499357]
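These values can be checked against the closed form. The derivative of σ(x) = 1/(1 + e^(−x)) is σ(x)(1 − σ(x)). Because sum_logistic adds independent terms, its gradient is that derivative applied elementwise. `grad` needs a scalar-valued function with floating-point inputs, which is why the input is `jnp.arange(3.)` rather than `jnp.arange(3)`. The sketch below uses `jax.nn.sigmoid` for σ. It also shows that the function returned by `grad` composes with `jit`.

```python
import jax.numpy as jnp
from jax import grad, jit, nn

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
auto = grad(sum_logistic)(x_small)

# Closed form: d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)), elementwise.
s = nn.sigmoid(x_small)
manual = s * (1.0 - s)
print(jnp.allclose(auto, manual))  # True

# grad returns an ordinary Python function, so it can be wrapped in jit.
fast_derivative = jit(grad(sum_logistic))
print(fast_derivative(x_small))
```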

Auto-vectorization with vmap()
JAX has one more transformation in its API that you might find useful: vmap(), the vectorizing map. 
It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, 
it pushes the loop down into a function’s primitive operations for better performance. 
When composed with jit(), it can be just as fast as adding the batch dimensions by hand.

We’re going to work with a simple example, and promote matrix-vector products into matrix-matrix products using vmap(). 
Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.

In [None]:
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)
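To batch `apply_matrix` over the 10 rows of `batched_x`, the sketch below compares three approaches. The setup is repeated so the block runs on its own. The helper function names are illustrative. All three should agree up to floating-point tolerance. The vmap version keeps the single-vector function unchanged.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
    # Written for ONE vector v of shape (100,); returns shape (150,).
    return jnp.dot(mat, v)

# 1. Naive: Python loop outside the function, then stack the results.
def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])

# 2. Manual: rewrite the math by hand for a (10, 100) batch.
@jax.jit
def batched_apply_matrix(v_batched):
    return jnp.dot(v_batched, mat.T)

# 3. vmap: JAX pushes the batch loop into the primitive operations.
@jax.jit
def vmap_batched_apply_matrix(v_batched):
    return jax.vmap(apply_matrix)(v_batched)

out_naive = naively_batched_apply_matrix(batched_x)
out_manual = batched_apply_matrix(batched_x)
out_vmap = vmap_batched_apply_matrix(batched_x)
print(out_vmap.shape)  # (10, 150)
```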

Of course, vmap() can be arbitrarily composed with jit(), grad(), and any other JAX transformation.
This is just a taste of what JAX can do. We’re really excited to see what you do with it!
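A common example of composing the three transformations is computing per-example gradients. `grad` differentiates a single-example loss. `vmap` maps it over the batch axis of the data but not the weights (`in_axes=None` for `w`). `jit` compiles the result. The linear model and `loss` function are hypothetical, chosen so the result can be checked against the closed form 2(x·w − y)x.

```python
import jax
import jax.numpy as jnp
from jax import random

# Hypothetical toy loss for ONE example: squared error of a linear prediction.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

key_w, key_x, key_y = random.split(random.PRNGKey(0), 3)
w = random.normal(key_w, (3,))
xs = random.normal(key_x, (8, 3))   # batch of 8 examples
ys = random.normal(key_y, (8,))

# grad w.r.t. w, vmapped over axis 0 of xs and ys (w is shared), then jit-compiled.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

g = per_example_grads(w, xs, ys)
print(g.shape)  # (8, 3): one gradient vector per example
```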

All TorchVision datasets are subclasses of torch.utils.data.Dataset, i.e., they have __getitem__ and __len__ methods implemented. Hence, they can all be passed to a torch.utils.data.DataLoader, which can load multiple samples in parallel using torch.multiprocessing workers.

All the datasets have a similar API. They all have two common arguments, transform and target_transform, to transform the input and target respectively. You can also create your own datasets using the provided base classes.
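Any class that implements `__len__` and `__getitem__` on top of `torch.utils.data.Dataset` can be passed to a `DataLoader`. The sketch below is a minimal custom dataset. `SquaresDataset` is a hypothetical toy class whose `transform`/`target_transform` arguments follow the convention described above. `num_workers` is left at its default (0). Setting it above 0 would load samples in worker processes.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the float i, its label is i**2."""

    def __init__(self, n, transform=None, target_transform=None):
        self.n = n
        self.transform = transform                # applied to the input
        self.target_transform = target_transform  # applied to the label

    def __len__(self):
        # Number of samples; DataLoader's sampler draws indices in [0, n).
        return self.n

    def __getitem__(self, idx):
        x = torch.tensor([float(idx)])
        y = torch.tensor(float(idx) ** 2)
        if self.transform:
            x = self.transform(x)
        if self.target_transform:
            y = self.target_transform(y)
        return x, y

ds = SquaresDataset(10, transform=lambda x: x / 10.0)
loader = DataLoader(ds, batch_size=4, shuffle=False)
xb, yb = next(iter(loader))
print(len(ds), xb.shape, yb.shape)  # 10 torch.Size([4, 1]) torch.Size([4])
```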

DATASETS & DATALOADERS
Code for processing data samples can get messy and hard to maintain; we ideally want our dataset code to be decoupled from our model training code for better readability and modularity. PyTorch provides two data primitives: torch.utils.data.DataLoader and torch.utils.data.Dataset that allow you to use pre-loaded datasets as well as your own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
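The split of responsibilities can be seen with the built-in `TensorDataset`. The Dataset answers "give me sample i". The DataLoader turns that into an iterable of shuffled mini-batches. The tensors below are illustrative.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Dataset: stores samples and labels; indexing returns one (sample, label) pair.
features = torch.arange(20, dtype=torch.float32).reshape(10, 2)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)
sample, label = dataset[3]

# DataLoader: wraps the Dataset in an iterable of shuffled mini-batches.
torch.manual_seed(0)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch_sizes, seen = [], []
for xb, yb in loader:
    batch_sizes.append(len(xb))
    seen.extend(yb.tolist())
print(batch_sizes)  # [4, 4, 2]: the last batch is smaller unless drop_last=True
```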

PyTorch domain libraries provide a number of pre-loaded datasets (such as FashionMNIST) that subclass torch.utils.data.Dataset and implement functions specific to the particular data. They can be used to prototype and benchmark your model. They are listed under Image Datasets (TorchVision), Text Datasets (TorchText), and Audio Datasets (TorchAudio).

Loading a Dataset
Here is an example of how to load the Fashion-MNIST dataset from TorchVision. Fashion-MNIST is a dataset of Zalando’s article images consisting of 60,000 training examples and 10,000 test examples. Each example comprises a 28×28 grayscale image and an associated label from one of 10 classes.

We load the FashionMNIST Dataset with the following parameters:
root is the path where the train/test data is stored.
train specifies whether to load the training or the test split.
download=True downloads the data from the internet if it’s not available at root.
transform and target_transform specify the feature and label transformations.

In [None]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

jax.random.PRNGKey
jax.random.PRNGKey(seed)
Create a pseudo-random number generator (PRNG) key given an integer seed.
Parameters
seed (int): a 64- or 32-bit integer used as the value of the key.
Return type: Union[Any, PRNGKeyArray]
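In JAX, the random state lives in the key you pass in, not in a hidden global generator. Calling a random function twice with the same key gives the same numbers. To get fresh randomness, split the key with `jax.random.split` and use each subkey once, as in this short sketch.

```python
import jax
from jax import random

key = random.PRNGKey(42)   # explicit PRNG state derived from an integer seed

# Same key -> same numbers: random functions are pure and never mutate the key.
a = random.normal(key, (3,))
b = random.normal(key, (3,))

# For new numbers, split the key and consume each subkey only once.
key, subkey = random.split(key)
c = random.normal(subkey, (3,))

print(a, b, c)
```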

Often, we want to operate on objects that look like dicts of arrays, or lists of lists of dicts, or other nested structures. In JAX, we refer to these as pytrees, but you can sometimes see them called nests, or just trees.

JAX has built-in support for such objects, both in its library functions and through the functions in jax.tree_util (with the most common ones also available as jax.tree_*). This section will explain how to use them, give some useful snippets, and point out common gotchas.
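A small sketch of the basic jax.tree_util operations on a nested structure. The `params` layout is illustrative. `tree_leaves` flattens the leaves, and `tree_map` applies a function to every leaf while keeping the container structure. `tree_flatten`/`tree_unflatten` round-trip between the two forms.

```python
import jax.numpy as jnp
from jax import tree_util

# A pytree: nested dicts/lists/tuples with arrays or scalars at the leaves.
params = {
    "linear": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "scales": [1.0, (2.0, 3.0)],
}

# tree_leaves: flat list of leaves (two arrays + three floats).
leaves = tree_util.tree_leaves(params)
print(len(leaves))  # 5

# tree_map: apply a function to every leaf, rebuilding the same structure.
doubled = tree_util.tree_map(lambda x: x * 2, params)
print(doubled["scales"])  # [2.0, (4.0, 6.0)]

# tree_flatten / tree_unflatten: split into (leaves, treedef) and rebuild.
flat, treedef = tree_util.tree_flatten(params)
rebuilt = tree_util.tree_unflatten(treedef, flat)
```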