# Using PyTorch to Set Up the MNIST Dataset for Training

After you have successfully installed PyTorch, you can use the sample code below to load the MNIST dataset.
You will use this dataset to train your neural networks.

```python
# Import necessary packages

# Jupyter/IPython magics: draw plots inline, at high resolution
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import numpy as np
import torch

import matplotlib.pyplot as plt
```

## 1. Load Images

The MNIST dataset consists of greyscale images of handwritten digits. Each image is 28x28 pixels. You can see a sample below.

<img src='mnist.png'>

Our goal is to build a neural network that can take one of these images and predict the digit in the image.

First up, we need to get our dataset. It is provided through the `torchvision` package. The code below downloads the MNIST dataset to a local folder of your choice (in this example, `~/.pytorch/MNIST_data/`) and creates the training dataset for us. Passing `train=False` instead of `train=True` would load the test set.

```python
### Run this cell

from torchvision import datasets, transforms

# Define a transform that converts each image to a tensor with values in [0, 1],
# then normalizes it with mean 0.5 and std 0.5, which maps the values to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,)),
                                ])

# Download and load the training data
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
```
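To see what `transform` does to each image, here is a minimal NumPy sketch of the same arithmetic. It assumes a 2-D `uint8` greyscale image like an MNIST digit; `to_tensor_and_normalize` is a hypothetical helper, not part of torchvision. `ToTensor()` scales pixel values from [0, 255] to [0.0, 1.0] and adds a leading channel axis. `Normalize((0.5,), (0.5,))` then computes `(x - 0.5) / 0.5`, which maps [0, 1] onto [-1, 1].

```python
import numpy as np

def to_tensor_and_normalize(img, mean=0.5, std=0.5):
    """Hypothetical helper mimicking ToTensor() + Normalize((mean,), (std,))
    for a single 2-D greyscale uint8 image of shape (H, W)."""
    # ToTensor: scale uint8 pixels from [0, 255] to float32 in [0.0, 1.0]
    # and add a leading channel axis: (H, W) -> (1, H, W)
    x = img.astype(np.float32)[np.newaxis, :, :] / 255.0
    # Normalize: subtract the mean, then divide by the std
    return (x - mean) / std

# A made-up 28x28 image: black background with a single white pixel
img = np.zeros((28, 28), dtype=np.uint8)
img[14, 14] = 255

out = to_tensor_and_normalize(img)
print(out.shape)             # (1, 28, 28)
print(out.min(), out.max())  # -1.0 1.0
```

Black pixels (0) end up at -1 and white pixels (255) at +1, so the inputs are centred around zero.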

## 2. Datasets and Batches

We have the training data loaded into `trainset`, and `trainloader` serves it in batches. We can turn the loader into an iterator with `iter(trainloader)`. Later, we'll loop through the dataset for training like this:

```python
for images, labels in trainloader:
    ## do things with images and labels
    pass
```

You'll notice I created the `trainloader` with a batch size of 64 and `shuffle=True`. The batch size is the number of images the data loader returns in one iteration; together, these images form a *batch* that we pass through our network. `shuffle=True` tells the loader to reshuffle the dataset every time we start going through it again, that is, at the start of each epoch. Here, though, I'm just grabbing the first batch so we can check out the data. The output below shows that `images` is a tensor of size `(64, 1, 28, 28)`: 64 images per batch, each with 1 color channel (greyscale) and 28x28 pixels.
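The batching and shuffling described above can be modeled in a few lines of plain Python. This is a simplified sketch, not how `DataLoader` is implemented; the helper `epoch_batches` and its `seed` parameter are hypothetical. Each call draws a random order of indices and slices it into batches. With 60,000 training images and a batch size of 64, that gives 938 batches, and the last one holds only 32 images.

```python
import random

def epoch_batches(num_items, batch_size, shuffle=True, seed=None):
    """Hypothetical helper: yield one list of dataset indices per batch, for one epoch.

    A simplified model of DataLoader's default batching, not its actual code.
    """
    indices = list(range(num_items))
    if shuffle:
        # With seed=None, each call (i.e. each epoch) gets a different order
        random.Random(seed).shuffle(indices)
    for start in range(0, num_items, batch_size):
        # The final batch may be smaller than batch_size
        yield indices[start:start + batch_size]

batches = list(epoch_batches(60000, 64, seed=0))
print(len(batches))      # 938
print(len(batches[0]))   # 64
print(len(batches[-1]))  # 32
```

Every index appears exactly once per epoch; shuffling only changes which images end up together in a batch.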

```python
print('Total number of images in the trainset:', len(trainset))  # 60,000 training images
print('Total number of batches:', len(trainloader))               # 938: 937 full batches + 1 of 32
dataiter = iter(trainloader)
images, labels = next(dataiter)  # grab the first batch
print('Data type of images:', type(images))
print('Shape of images:', images.shape)
print('Shape of labels:', labels.shape)
```

## 3. Visualize the Images

This is what one of the images looks like. `i` can be any value from 0 to 63. Change it to look at other images in this batch.

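```python
i = 3
print(labels[i])                               # the digit label of image i
# print(images[i])                             # uncomment to see the raw pixel values
print(images[i].flatten().shape)               # torch.Size([784]): 1 * 28 * 28 values
print(torch.flatten(images[i]).shape)          # same result, function form
print(images.view(images.shape[0], -1).shape)  # torch.Size([64, 784]): one row per image
# squeeze() drops the size-1 channel axis so imshow receives a 28x28 array
plt.imshow(images[i].numpy().squeeze(), cmap='Greys_r');
```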
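The shape manipulations in the cell above can be reproduced with NumPy, where `reshape` plays the role of PyTorch's `flatten`/`view` and `squeeze` drops size-1 axes. In this sketch, `batch` is a hypothetical all-zeros stand-in with the same shape as `images`. Passing `-1` asks for that dimension to be inferred, here 1 * 28 * 28 = 784. This flattened form is what a fully connected network typically takes as input.

```python
import numpy as np

# Hypothetical stand-in for `images`: (batch, channel, height, width)
batch = np.zeros((64, 1, 28, 28), dtype=np.float32)

one = batch[3]                                  # a single image: (1, 28, 28)
flat_one = one.reshape(-1)                      # like images[i].flatten(): (784,)
flat_batch = batch.reshape(batch.shape[0], -1)  # like images.view(64, -1): (64, 784)
for_imshow = one.squeeze()                      # channel axis dropped: (28, 28)

print(one.shape, flat_one.shape, flat_batch.shape, for_imshow.shape)
# (1, 28, 28) (784,) (64, 784) (28, 28)
```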

## 4. More Experiments with DataLoader

`DataLoader` is the PyTorch class that wraps a dataset and serves it in batches; `trainloader` above is one.

Here is more about `DataLoader`. Experiment with the code below and find out:
- How to control the total number of images you want to use for training
- How to configure the batch size
- Any other ways you want to use the DataLoader
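One way to control the total number of training images is to pick a subset of indices. The sketch below does this in plain Python. `subset_indices` is a hypothetical helper, and the fixed `seed` is an assumption made so the same subset is chosen on every run.

```python
import random

def subset_indices(dataset_len, n, seed=0):
    """Hypothetical helper: pick n distinct random indices out of dataset_len."""
    if not 0 <= n <= dataset_len:
        raise ValueError("n must be between 0 and dataset_len")
    rng = random.Random(seed)
    # sample() draws without replacement, so no image is picked twice
    return rng.sample(range(dataset_len), n)

idx = subset_indices(60000, 1000)
print(len(idx), len(set(idx)))  # 1000 1000
```

In PyTorch, such a list of indices can be wrapped as `torch.utils.data.Subset(trainset, idx)` and passed to `DataLoader` with whatever `batch_size` you choose.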


```python
# Peek at the shapes of a single batch from trainloader
for images, labels in trainloader:
    print('images.shape: {}, labels.shape: {}'.format(images.shape, labels.shape))
    break

# The following illustrates the use of DataLoader with batch_size and drop_last
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """A toy dataset of `size` random numbers, each stored as a 1-element tensor."""
    def __init__(self, size):
        self.x = torch.randn(size, 1)

    def __getitem__(self, index):
        return self.x[index]

    def __len__(self):
        return len(self.x)

dataset = MyDataset(1001)

# By default, the last batch keeps the 1 leftover sample: 101 batches
data_loader = DataLoader(dataset, batch_size=10)
print('Number of batches:', len(data_loader))
for batch_idx, data in enumerate(data_loader):
    if batch_idx > 95:
        print('batch idx {}, batch len {}'.format(batch_idx, len(data)))

print('Here is what happens if we drop the last incomplete batch while loading:')

# With drop_last=True, the incomplete last batch is discarded: 100 batches
data_loader = DataLoader(dataset, batch_size=10, drop_last=True)
print('Number of batches:', len(data_loader))
for batch_idx, data in enumerate(data_loader):
    if batch_idx > 95:
        print('batch idx {}, batch len {}'.format(batch_idx, len(data)))
```
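The batch counts in this experiment follow from simple arithmetic. With `n` samples and batch size `b`, the loader yields ceil(n / b) batches by default and floor(n / b) when `drop_last=True`. Here is a small sketch; `num_batches` is a hypothetical helper that assumes the default sampler, which draws each sample exactly once per epoch.

```python
def num_batches(n, batch_size, drop_last=False):
    """Hypothetical helper: batches per epoch for n samples."""
    if drop_last:
        return n // batch_size                # floor: incomplete last batch discarded
    return (n + batch_size - 1) // batch_size  # ceil: incomplete last batch kept

print(num_batches(1001, 10))                  # 101
print(num_batches(1001, 10, drop_last=True))  # 100
print(num_batches(60000, 64))                 # 938
```

The integer form of the ceiling avoids floating-point division. The size of the incomplete last batch is `n % batch_size`, which is 1 for the toy dataset and 32 for MNIST with a batch size of 64.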
