In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms 
from torch import optim

%load_ext autoreload
%autoreload 3

from model import CarlaNet
from utils import visualizeImages , cost ,loss_mask  , StopEarly  ,load_saved_model
from data import AgentData , ToTensor 


  from ._conv import register_converters as _register_converters


### Set up training environment 

In [2]:
batch_size = 200
BRANCH_LOSS_WEIGHT= [0.95, 0.95, 0.95, 0.95] # how much each branch is weighted when computing loss
# params used in loss, we will update this None values while training 
params = {"branch_weights":BRANCH_LOSS_WEIGHT,\
          # how much each of the outputs specified on TARGETS are weighted for learning.
         "variable_weights":{"Steer":0.5,"Gas":0.45,"Brake":0.05},\
          # values we used them while training 
         "branches":None , "targets": None , "controls_mask":None}
loss_type = "L1"
model = CarlaNet().cuda()
opt = optim.SGD(model.parameters(),lr=0.01,momentum=0.9,nesterov=True)
stop_early = StopEarly()

## prepare Data

In [3]:
transform = transforms.Compose([ToTensor()])
training_data = AgentData('./SeqTrain',transforms=transform)
train_loader = DataLoader(training_data,batch_size=batch_size)
data_iterator = iter(train_loader)

In [4]:
# info about Data 
print(f"num of trainin samples is : {len(training_data)}")

num of trainin samples is : 226000


In [5]:
num_of_batches = len(training_data)/batch_size
num_of_batches

1130.0

## Training

In [6]:
def train_net(n_epochs):
    
    #check if we have saved checkpoint and load it if exist
    load_saved_model(model)
    # prepare the net for training
    model.train()

    for epoch in range(n_epochs):  # loop over the dataset multiple times
        
        running_loss = 0.0
        epoch_loss = 0.0
        
        # train on batches of data, assumes you already have train_loader
        for batch_i, data in enumerate(train_loader):
            # get inputs(img,speed,command) and targets(steer,gas,brake)
            inputs ,targets = data
            imgs , speed , command = inputs
            imgs , speed , command = imgs.type(torch.cuda.FloatTensor) ,speed.type(torch.cuda.FloatTensor)\
                                    ,command.type(torch.cuda.LongTensor)
            
            speed , command = speed.view((batch_size,1)) , command.view((batch_size,1))
            inputs = (imgs , speed , command)
            
            # targets 
            steer , gas , brake = targets
            targets = torch.stack([steer,gas,brake],dim=1).cuda()

            # forward pass to get outputs
            outs = model(inputs)
            controls_masks = loss_mask(command)
            
            #update params
            params["branches"] = outs
            params["targets"] = targets
            params["controls_mask"] = controls_masks

            # calculate the loss between predicted and target keypoints
            loss = cost(params=params,type_loss=loss_type)

            # zero the parameter (weight) gradients
            opt.zero_grad()
            
            # backward pass to calculate the weight gradients
            loss.backward()

            # update the weights
            opt.step()

            # print loss statistics
            # to convert loss into a scalar and add it to the running_loss, use .item()
            running_loss += loss.item()
            epoch_loss += running_loss
            if (batch_i +1) % 100 == 0:    # print every 100 batches
                print('Epoch: {}, Batch: {}, Avg. Loss: {}'.format(epoch + 1, batch_i+1, running_loss/100))
                running_loss = 0.0
        print('----------------\nEpoch loss: {}\n------------\n'.format(epoch_loss/num_of_batches))
        if stop_early(model,epoch_loss/num_of_batches) :
            print("Early stop activated ... stoping training process..")
            break
    print('Finished Training')

In [8]:
train_net(20) # test train with 20 + 3 in the previse session = 23 iter

checkpoint loaded.

Epoch: 1, Batch: 100, Avg. Loss: 0.0018992124107899144
Epoch: 1, Batch: 200, Avg. Loss: 0.0018389033508719877
Epoch: 1, Batch: 300, Avg. Loss: 0.0018614178083953448
Epoch: 1, Batch: 400, Avg. Loss: 0.0022299541140091608
Epoch: 1, Batch: 500, Avg. Loss: 0.0032677691581193356
Epoch: 1, Batch: 600, Avg. Loss: 0.003032795008912217
Epoch: 1, Batch: 700, Avg. Loss: 0.00314891014539171
Epoch: 1, Batch: 800, Avg. Loss: 0.0032294048025505616
Epoch: 1, Batch: 900, Avg. Loss: 0.003066960951546207
Epoch: 1, Batch: 1000, Avg. Loss: 0.0034852452372433618
Epoch: 1, Batch: 1100, Avg. Loss: 0.002393294073990546
----------------
Epoch loss: 0.13279929974016452
------------

Epoch: 2, Batch: 100, Avg. Loss: 0.0016854530840646476
Epoch: 2, Batch: 200, Avg. Loss: 0.0018373145427904092
Epoch: 2, Batch: 300, Avg. Loss: 0.001780009117210284
Epoch: 2, Batch: 400, Avg. Loss: 0.0022529468542779795
Epoch: 2, Batch: 500, Avg. Loss: 0.003226497326104436
Epoch: 2, Batch: 600, Avg. Loss: 0.0029484