In [1]:
# Import libraries
import os

import torch
import torchvision
import torchvision.transforms as transforms


# Import custom modules and packages
import params.lenet_mnist

### Load the learning parameters

In [2]:
# Learning hyper-parameters for LeNet/MNIST, pulled from the project's params
# module (later cells index it with keys such as 'batch_size').
lenet_mnist_config = params.lenet_mnist
LEARNING_PARAMS = lenet_mnist_config.LEARNING

### Prepare the data

In [6]:
# Preprocessing pipeline for the MNIST images.
# Building the Compose object only records the steps; nothing is applied to
# the data until a sample is fetched from a dataset.
preprocessing_steps = [
    # Resize every image to params.lenet_mnist.IMAGE_SHAPE
    transforms.Resize(params.lenet_mnist.IMAGE_SHAPE),
    # PIL Image / numpy.ndarray -> torch tensor (Normalize below only accepts
    # torch.*Tensor inputs, so this step must come first)
    transforms.ToTensor(),
    # Standardize with the pre-computed mean/std stored in the params module;
    # these should match the statistics of the data the model(s) were trained on
    transforms.Normalize(**params.lenet_mnist.NORMALIZE_PARAMS),
]
transform = transforms.Compose(preprocessing_steps)

# MNIST train and test splits, stored (and downloaded if not already present)
# under the current directory; `transform` is applied lazily per sample.
mnist_common_kwargs = dict(root='.', transform=transform, download=True)

train_dataset = torchvision.datasets.MNIST(train=True, **mnist_common_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **mnist_common_kwargs)

# Number of worker processes used for asynchronous data loading.
# 12 is kept as an upper bound, but it is capped by the CPUs actually usable by
# this process (affinity-aware where the OS supports it, as PyTorch itself
# does) so smaller machines/containers are not oversubscribed.
_usable_cpus = (
    len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity")
    else (os.cpu_count() or 1)
)
NUM_WORKERS = min(12, _usable_cpus)

# Pinned (page-locked) host memory only speeds up host-to-GPU transfers; when
# CUDA is unavailable it brings no benefit and PyTorch emits a warning.
PIN_MEMORY = torch.cuda.is_available()

# Combine a dataset and a sampler, and provide an iterable over the dataset.
# Setting `shuffle=True` makes the DataLoader build a RandomSampler internally,
# so no Sampler object has to be created by hand.
# NOTE(review): shuffling is not seeded, so batch order differs between runs;
# pass a seeded `generator=` if reproducible runs are needed.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=LEARNING_PARAMS['batch_size'],
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

# Evaluation data is read in order (shuffle=False -> SequentialSampler).
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=LEARNING_PARAMS['batch_size'],
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

### Device selection

In [7]:
# Select the compute device: the GPU when CUDA can be used, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}\n")

Device: cuda



### Training and validation

#### Visualize the loss and accuracy

### Testing