# **Road Following - Training Model**

We will train a neural network that takes an input image and outputs a set of x, y values, one pair per target.

We will use the **PyTorch** deep learning framework to train a **ResNet-18** model.

In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import glob
import PIL.Image
import os
import numpy as np

## **1. Create a dataset instance**
Subclass ``torch.utils.data.Dataset`` and implement the ``__len__`` and ``__getitem__`` methods. This class loads the images and parses the x, y values of each target from the image file names.
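The slicing in ``get_label`` implies a fixed file-name layout. As a minimal sketch, assuming (hypothetically) names of the form ``xy_XXX_YYY_XXX_YYY_<id>.jpg``, where each ``XXX``/``YYY`` is a three-digit value on a 0..100 scale (inferred from the ``(v - 50) / 50`` normalization), the same parsing in plain Python looks like this:

```python
# Sketch of the label parsing in XYDataset.get_label.
# Assumption: file names look like "xy_XXX_YYY_XXX_YYY_<id>.jpg" with
# zero-padded values on a 0..100 scale. parse_targets is a hypothetical helper.

def parse_targets(filename, target_number):
    """Return [x0, y0, x1, y1, ...] normalized from 0..100 to -1..1."""
    label = []
    for i in range(target_number):
        offset = 8 * i + 3  # skip the 3-char prefix, then 8 chars ("XXX_YYY_") per target
        x = int(filename[offset:offset + 3])
        y = int(filename[offset + 4:offset + 7])
        label.append((x - 50.0) / 50.0)
        label.append((y - 50.0) / 50.0)
    return label

print(parse_targets("xy_030_070_060_045_0001.jpg", 2))  # [-0.4, 0.4, 0.2, -0.1]
```

A file name that does not follow this layout makes ``int()`` raise ``ValueError`` (or, worse, silently yields wrong labels if the digits happen to line up differently).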

In [None]:
class XYDataset(torch.utils.data.Dataset):

    def __init__(self, directory, target_number=1, random_hflips=False):
        self.directory = directory
        self.target_number = target_number
        self.random_hflips = random_hflips
        self.image_paths = glob.glob(os.path.join(self.directory, '*.jpg'))
        self.color_jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.3)

    def get_label(self, path):
        label = []
        for i in range(self.target_number):
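            # Each target occupies 8 characters ("XXX_YYY_") after a 3-character prefix
            offset = 8*i+3
            # Normalize raw values from 0..100 to -1..1
            target_x = (float(int(path[offset: offset+3])) - 50.0) / 50.0
            target_y = (float(int(path[offset+4 : offset+7])) - 50.0) / 50.0
            label.append(target_x)
            label.append(target_y)
            # target 0: x = path[3:6],   y = path[7:10]
            # target 1: x = path[11:14], y = path[15:18]
            # target 2: x = path[19:22], y = path[23:26]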
        return torch.tensor(label).float()


    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]

        image = PIL.Image.open(image_path).convert('RGB')

        # Labels are parsed from the file name: [x0, y0, x1, y1, ...]
        label = self.get_label(os.path.basename(image_path))

        # Data augmentation
        if self.random_hflips:
            if np.random.rand() > 0.5:
                image = transforms.functional.hflip(image)
                label[0::2] = -label[0::2]  # mirror every x coordinate; y is unchanged

        image = self.color_jitter(image)
        image = transforms.functional.resize(image, (224, 224))
        image = transforms.functional.to_tensor(image)
        # Reverse the channel order (RGB -> BGR), kept as in the original code,
        # presumably to match the channel order of the frames used at inference
        image = image.numpy()[::-1].copy()
        image = torch.from_numpy(image)
        image = transforms.functional.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        return image, label
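When ``random_hflips`` is enabled, the image is mirrored left to right, so every x label has to be mirrored as well. A quick check in plain Python, assuming the raw x values are on the 0..100 scale with 50 at the image center, shows that negating the normalized x gives the same result as mirroring the raw coordinate (``x -> 100 - x``). ``hflip_label`` is a hypothetical helper that negates the even-indexed (x) entries of the label.

```python
# Sketch: why a horizontal flip negates the normalized x label.
# Assumption: raw x values are on a 0..100 scale with 50 at the image center.

def normalize(v):
    """Map a raw 0..100 coordinate to -1..1, as get_label does."""
    return (v - 50.0) / 50.0

def hflip_label(label):
    """Hypothetical helper: negate every x entry (even index) of [x0, y0, x1, y1, ...]."""
    return [-v if i % 2 == 0 else v for i, v in enumerate(label)]

raw_x, raw_y = 30, 70
label = [normalize(raw_x), normalize(raw_y)]
print(hflip_label(label))        # [0.4, 0.4]
print(normalize(100 - raw_x))    # 0.4  (mirrored raw x, then normalized)
```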


In [None]:
dataset = XYDataset('data/dataset_xy', 2, random_hflips=False)

## **2. Split the dataset into training and test sets**
Randomly split the dataset so that 30% of the samples (``test_percent``) are held out for testing and the rest are used for training.

In [None]:
test_percent = 0.3
num_test = int(test_percent * len(dataset))
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - num_test, num_test])
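``random_split`` receives explicit lengths, so the split sizes follow from ``int()`` truncation: with N samples, the test set gets ``int(0.3 * N)`` samples and the training set gets the rest. The plain-Python sketch below reproduces only the sizes and the shuffle-then-cut idea; it uses Python's ``random`` module rather than PyTorch's generator, so the actual indices differ, and ``split_indices`` is a hypothetical helper.

```python
import random

def split_indices(n, test_percent=0.3, seed=0):
    """Shuffle 0..n-1 and cut it into (train, test) index lists with the same
    sizes passed to random_split: num_test = int(test_percent * n)."""
    num_test = int(test_percent * n)
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    return indices[:n - num_test], indices[n - num_test:]

train_idx, test_idx = split_indices(1000)
print(len(train_idx), len(test_idx))  # 700 300
```

To make the real split reproducible between runs, ``random_split`` also accepts a ``generator`` argument, for example ``generator=torch.Generator().manual_seed(0)``.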

## **3. Create a data loader to load data in batches**
Use the ``DataLoader`` class to load the data in batches. It can shuffle the data and load samples in parallel worker subprocesses (``num_workers``; ``0`` means loading happens in the main process). We use a batch size of 8.

In [None]:
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,  # sample order does not matter for evaluation
    num_workers=0
)
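With ``batch_size=8`` and the default ``drop_last=False``, a loader yields ``ceil(len(dataset) / 8)`` batches per epoch, and the last batch may be smaller. The sketch below models that grouping in plain Python for a hypothetical training split of 700 samples; ``iterate_batches`` is an illustrative helper, not part of PyTorch.

```python
import math
import random

def iterate_batches(indices, batch_size=8, shuffle=True, seed=0):
    """Yield lists of indices the way a DataLoader with drop_last=False groups
    samples: the final batch holds whatever is left over."""
    order = list(indices)
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield order[start:start + batch_size]

batches = list(iterate_batches(range(700), batch_size=8))
print(len(batches), math.ceil(700 / 8), len(batches[-1]))  # 88 88 4
```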

## **4. Define the neural network**
We use the ``ResNet-18`` model from **torchvision** and load its ImageNet-pretrained weights for **transfer learning**.

In [None]:
target_number = 2

In [None]:
model = models.resnet18(pretrained=True)

The final fully connected (``fc``) layer of ResNet-18 has ``in_features`` = 512. Since we are doing regression, we replace it with a layer whose ``out_features`` is **2 x target_number** (one x, y pair per target).

Finally, move the model to the GPU for training.

In [None]:
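# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Replace the classification head with a regression head: 2 outputs (x, y) per target
model.fc = torch.nn.Linear(512, 2*target_number)
model = model.to(device)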
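The new head is a single linear layer whose output layout has to match the label layout produced by ``get_label`` (``[x0, y0, x1, y1, ...]``). The plain-Python sketch below counts its parameters for ``target_number = 2`` and shows how a flat prediction regroups into (x, y) pairs; both helpers are hypothetical.

```python
def regression_head_params(in_features=512, target_number=2):
    """Parameters in Linear(in_features, 2 * target_number): weights plus biases."""
    out_features = 2 * target_number  # one (x, y) pair per target
    return in_features * out_features + out_features

def to_pairs(outputs):
    """Group a flat [x0, y0, x1, y1, ...] prediction into (x, y) tuples."""
    return list(zip(outputs[0::2], outputs[1::2]))

print(regression_head_params())          # 2052
print(to_pairs([0.1, -0.2, 0.3, 0.4]))   # [(0.1, -0.2), (0.3, 0.4)]
```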

## **5. Define training parameters**

#### 5.1 Define the training widgets

In [None]:
import ipywidgets
# parameters
epochs_widget = ipywidgets.IntText(description='epochs', value=50)
model_path_widget = ipywidgets.Text(description='model path', value='model/road_following_model.pth')
# training
steps_widget = ipywidgets.IntText(description='steps', value=0)
train_progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress')
train_loss_widget = ipywidgets.FloatText(description='train loss')
eval_progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress')
eval_loss_widget = ipywidgets.FloatText(description='eval loss')
save_info_widget = ipywidgets.Text(description='save info')
train_button = ipywidgets.Button(description='START', button_style='warning',layout=ipywidgets.Layout(width='300px', height='28px'))

train_eval_widget = ipywidgets.VBox([
    ipywidgets.Label('-'*31+'parameters'+'-'*31),
    epochs_widget,
    model_path_widget,
    ipywidgets.Label('-'*29+'training'+'-'*29),
    steps_widget,
    train_progress_widget,
    train_loss_widget,
    eval_progress_widget,
    eval_loss_widget,
    save_info_widget,
    ipywidgets.Label('-'*70),
    train_button
])

#### 5.2 Define the training and evaluation function

In [None]:
def train_eval(change):
    global model
    NUM_EPOCHS = epochs_widget.value
    MODEL_PATH = model_path_widget.value  # save path
    best_loss = 1e9  # lowest test loss seen so far
    # Prefix the file name (not the directory) with best_/last_
    model_dir, model_file = os.path.split(MODEL_PATH)
    BEST_MODEL_PATH = os.path.join(model_dir, 'best_' + model_file)
    LAST_MODEL_PATH = os.path.join(model_dir, 'last_' + model_file)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    optimizer = optim.Adam(model.parameters())  # optimizer
    for epoch in range(NUM_EPOCHS):
        steps_widget.value = epoch
        # Train on the training set for this epoch
        model.train()
        train_loss = 0.0
        for index, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()  # clear gradients from the previous step
            outputs = model(images)
            loss = F.mse_loss(outputs, labels)  # mean squared error
            train_loss += loss.item()  # accumulate the loss
            loss.backward()  # backpropagation: compute gradients
            optimizer.step()  # update the network parameters
            train_progress_widget.value = (index + 1) / len(train_loader)  # progress
            train_loss_widget.value = loss.item()  # current batch loss
        train_loss /= len(train_loader)  # average loss over batches
        train_loss_widget.value = train_loss

        # Evaluate on the test set for this epoch (no gradients needed)
        model.eval()
        test_loss = 0.0
        with torch.no_grad():
            for index, (images, labels) in enumerate(test_loader):
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = F.mse_loss(outputs, labels)
                test_loss += loss.item()
                # Display evaluation progress
                eval_progress_widget.value = (index + 1) / len(test_loader)
                eval_loss_widget.value = loss.item()
        test_loss /= len(test_loader)
        eval_loss_widget.value = test_loss

        print('%f, %f' % (train_loss, test_loss))
        torch.save(model.state_dict(), LAST_MODEL_PATH)
        if test_loss < best_loss:
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            best_loss = test_loss
            save_info_widget.value = 'step %d (train: %.4f, eval: %.4f)' % (epoch, train_loss, test_loss)
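``train_eval`` reports the mean of the per-batch ``mse_loss`` values. Because ``F.mse_loss`` (with its default ``reduction='mean'``) already averages within each batch, the samples in a smaller final batch carry more weight than the others, so the reported value can differ from the true per-sample mean. A small plain-Python illustration with made-up numbers:

```python
def mse(pred, target):
    """Mean squared error over all elements, like F.mse_loss with reduction='mean'."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

# Hypothetical 1-D predictions/targets for 5 samples; batch_size=2 gives batches of 2, 2, 1.
preds = [1.0, 1.0, 1.0, 1.0, 2.0]
targets = [0.0, 0.0, 0.0, 0.0, 0.0]
batch_size = 2
batch_losses = [mse(preds[i:i + batch_size], targets[i:i + batch_size])
                for i in range(0, len(preds), batch_size)]
epoch_loss = sum(batch_losses) / len(batch_losses)  # what train_eval reports
per_sample_loss = mse(preds, targets)                # loss over the whole split
print(batch_losses, epoch_loss, per_sample_loss)     # [1.0, 1.0, 4.0] 2.0 1.6
```

The two values coincide when every batch has the same size. For an exact per-sample mean, one option is to accumulate ``loss.item() * images.size(0)`` and divide by ``len(train_loader.dataset)``.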


In [None]:
train_button.on_click(train_eval)

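After every epoch, the latest weights are saved with a ``last_`` prefix on the model path, and the weights with the lowest test loss so far are saved with a ``best_`` prefix.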
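Two details of the checkpoint logic are easy to get wrong. String-prefixing the whole path (``'best_' + MODEL_PATH``) turns ``model/road_following_model.pth`` into ``best_model/road_following_model.pth``, which is a different directory. And the strict ``<`` comparison means that on a tie the earlier epoch is kept. The sketch below uses hypothetical helpers ``checkpoint_paths`` and ``best_epoch``: the first prefixes only the file name, the second picks the epoch whose weights end up in the ``best_`` file.

```python
import os

def checkpoint_paths(model_path):
    """Hypothetical helper: add the best_/last_ prefix to the file name only."""
    directory, filename = os.path.split(model_path)
    return (os.path.join(directory, 'best_' + filename),
            os.path.join(directory, 'last_' + filename))

def best_epoch(test_losses):
    """Return the epoch saved as best_ (first minimum, because of the strict '<')."""
    best_loss, best = float('inf'), None
    for epoch, loss in enumerate(test_losses):
        if loss < best_loss:
            best_loss, best = loss, epoch
    return best

print(checkpoint_paths('model/road_following_model.pth'))
print(best_epoch([0.30, 0.12, 0.15, 0.12, 0.20]))  # 1
```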

#### 5.3 Start Training

In [None]:
from IPython.display import display
display(train_eval_widget)

---
Parameters

``epochs_widget``: Set the number of training epochs

``model_path_widget``: Set the model save path

---
Training process

``steps_widget``: Display the index of the current epoch

``train_progress_widget``: Display the progress through the training batches of the current epoch

``train_loss_widget``: Display the loss of the current training batch, then the average training loss when the epoch ends

``eval_progress_widget``: Display the progress through the test batches of the current epoch

``eval_loss_widget``: Display the loss of the current test batch, then the average test loss when the epoch ends

``save_info_widget``: Display the epoch, training loss, and test loss of the most recently saved best model

---
Training switch

``train_button``