In [None]:
# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from torch import optim

from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode
from nemo.utils import logging

from nemo.collections.cv.modules.data_layers import STL10DataLayer
from nemo.collections.cv.modules.losses import NLLLoss
from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor
from nemo.collections.cv.modules.trainables import ImageEncoder, FeedForwardNetwork

# WARNING: the device is set to CPU so that the notebook can be executed on
# every machine. However, training on CPU will be extremely slow, so it is
# strongly suggested to set device to DeviceType.GPU instead.
device = DeviceType.CPU

# Create the Neural(Module)Factory, using the indicated device as placement.
# The same `device` value is reused later, when the graphs are moved with `.to(device)`.
nf = NeuralModuleFactory(placement=device)

### Tutorial III: Custom Training

In this third part of the Neural Graphs (NGs) tutorial we will focus on a different example: training of an image classification model with a ResNet-50 backbone on the STL-10 dataset using a custom training loop.

#### This part covers the following:

 * how to create separate graphs for training and evaluation
 * how to move a graph between CPU/GPU devices
 * how to parametrize data loaders
 * how to write a custom training loop relying on graph actions


In [None]:
# Both data layers use the same image size; only the dataset split differs.
image_size = {"height": 224, "width": 224}

# Data layers for training and validation.
# Note: this may take a while, as the dataset has to be downloaded and verified...
dl_train = STL10DataLayer(split="train", **image_size)
dl_valid = STL10DataLayer(split="test", **image_size)

# Loss module - the same instance is used by the training and the validation graph.
nll_loss = NLLLoss()

In [None]:
# Flattened size of the [7, 7, 2048] feature maps declared for the reshaper below.
flat_size = 7 * 7 * 2048  # = 100352
# Size of the classifier output (and of the log-softmax input).
num_outputs = 10

# Modules forming the "model" - a pretrained ResNet-50 serves as the image encoder.
# Note: this also might take some time, as the pretrained checkpoint has to be downloaded...
image_encoder = ImageEncoder(model_type="resnet50", pretrained=True, return_feature_maps=True)
# Freeze the encoder right away - to make the training faster.
image_encoder.freeze()

reshaper = ReshapeTensor(input_sizes=[-1, 7, 7, 2048], output_sizes=[-1, flat_size])
ffn = FeedForwardNetwork(
    input_size=flat_size,
    output_size=num_outputs,
    hidden_sizes=[100, 100],
    dropout_rate=0.1,
)
nl = NonLinearity(type="logsoftmax", sizes=[-1, num_outputs])

In [None]:
# Create the "model graph": encoder -> reshaper -> feed-forward network -> non-linearity.
# It is built once here and then nested inside both the training and the validation graphs.
with NeuralGraph(operation_mode=OperationMode.both) as stl10_resnet_classifier:
    # Bind the graph inputs to the encoder by passing the graph itself as `inputs`.
    feat_map = image_encoder(inputs=stl10_resnet_classifier)
    # Reshape the feature maps before feeding them to the feed-forward network.
    res_img = reshaper(inputs=feat_map)
    logits = ffn(inputs=res_img)
    # Apply the non-linearity module (configured as "logsoftmax") to the logits.
    preds = nl(inputs=logits)
    # Cherry-pick outputs: expose the final predictions under the "predictions" name.
    stl10_resnet_classifier.outputs["predictions"] = preds
    
# Ok, let us see what the graph looks like now.
logging.info(stl10_resnet_classifier.summary())

In [None]:
# Let us now compose a training graph: training data layer -> model graph -> loss.
with NeuralGraph(operation_mode=OperationMode.training) as training_graph:
    # Take outputs from the data layer. Only the second and third of its four
    # outputs are kept - they are used below as the images and the targets.
    _, x, t, _ = dl_train()
    # Pass the images to the model (the nested model graph).
    p = stl10_resnet_classifier(inputs=x)
    # Calculate the loss from the predictions and the targets.
    lss = nll_loss(predictions=p, targets=t)

# Ok, let us see what the graph looks like now.
logging.info(training_graph.summary())

In [None]:
# ... and a validation graph: validation data layer -> model graph -> loss.
# It reuses the same model graph and the same loss module as the training graph.
with NeuralGraph(operation_mode=OperationMode.evaluation) as validation_graph:
    # Take outputs from the data layer. Only the second and third of its four
    # outputs are kept - they are used below as the images and the targets.
    _, x_valid, t_valid, _ = dl_valid()
    # Pass the images to the model (the nested model graph).
    p_valid = stl10_resnet_classifier(inputs=x_valid)
    # Calculate the loss from the predictions and the targets.
    loss_valid = nll_loss(predictions=p_valid, targets=t_valid)

# This is how it looks now.
logging.info(validation_graph.summary())

In [None]:
# Move both graphs to the device selected at the top of the notebook.
for graph in (training_graph, validation_graph):
    graph.to(device)

# Adam optimizer working on the parameters exposed by the training graph.
learning_rate = 0.001
opt = optim.Adam(training_graph.parameters(), lr=learning_rate)

# The training loss is printed every `freq` steps.
freq = 10

In [None]:
# Finally, construct and run the custom training loop.
#
# Every epoch consists of two phases:
#   1. training   - one pass over the training set with parameter updates,
#   2. validation - one pass over the test set, reporting the average loss.

# Settings of the loop, kept in one place so they are easy to tune.
num_epochs = 5
batch_size = 64

for epoch in range(num_epochs):
    # Configure data loader used by the training graph - once per epoch.
    # Use default settings - just change the batch_size and turn sample shuffling on.
    training_graph.configure_data_loader(batch_size=batch_size, shuffle=True)

    # Iterate over the whole dataset - in batches.
    for step, batch in enumerate(training_graph.get_batch()):

        # Reset the gradients.
        opt.zero_grad()

        # Forward pass.
        outputs = training_graph.forward(batch)
        # Print loss - every `freq` steps.
        if step % freq == 0:
            logging.info("Epoch: {} Step: {} Training Loss: {}".format(epoch, step, outputs.loss))

        # Backpropagate the gradients.
        training_graph.backward()

        # Update the parameters.
        opt.step()
    # Epoch ended.

    # Evaluate graph on test set.
    # Configure data loader used by the validation graph - once per epoch.
    valid_losses = []
    validation_graph.configure_data_loader(batch_size=batch_size)
    # Iterate over the whole dataset - in batches.
    for batch in validation_graph.get_batch():
        # Forward pass.
        outputs = validation_graph.forward(batch)
        # Keep only a plain Python number. Storing the loss object itself would
        # keep every batch's loss (and any autograd history attached to it)
        # alive until the end of the epoch.
        # NOTE(review): assumes outputs.loss is a scalar (e.g. a single-element
        # tensor) - confirm against the loss module.
        valid_losses.append(float(outputs.loss))
    # Print average loss.
    logging.info("Epoch: {} Avg. Validation Loss: {}".format(epoch, sum(valid_losses) / len(valid_losses)))
