In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms

from dataloading.nvidia import NvidiaResizeAndCrop, Normalize, NvidiaDataset
from network import PilotNet
from trainer import Trainer

import wandb

%load_ext autoreload
%autoreload 2

## Datasets

In [2]:
# Directory under which all recordings referenced in this notebook are located.
root_path = Path("/home/romet/data/datasets/ut/nvidia-data")

# Recordings that make up the training split.
train_paths = [
    root_path / recording
    for recording in (
        "2021-05-20-12-36-10_e2e_sulaoja_20_30",
        "2021-05-20-12-43-17_e2e_sulaoja_20_30",
        "2021-05-20-12-51-29_e2e_sulaoja_20_30",
        "2021-05-20-13-44-06_e2e_sulaoja_10_10",
        "2021-05-20-13-51-21_e2e_sulaoja_10_10",
    )
]

# Per-sample preprocessing: NvidiaResizeAndCrop followed by Normalize.
tr = transforms.Compose([NvidiaResizeAndCrop(), Normalize()])

trainset = NvidiaDataset(train_paths, transform=tr)

# Shuffled mini-batches of 64, loaded by 3 worker processes that are kept
# alive between epochs; pinned memory is requested for the produced batches.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    num_workers=3,
    pin_memory=True,
    persistent_workers=True,
)

In [3]:
# A single recording forms the validation split.
valid_paths = [root_path / "2021-05-20-13-59-00_e2e_sulaoja_10_10"]

# Built with the transform object `tr`, the same one passed to the training set.
validset = NvidiaDataset(valid_paths, transform=tr)

# No shuffling: validation batches are served in dataset order.
validloader = torch.utils.data.DataLoader(
    validset,
    batch_size=64,
    shuffle=False,
)

## Model

In [4]:
# Path prefix under which this run's model files are written.
model_name = "models/1-pilotnet-base/1cam-v2"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Network and L1 loss, both placed on the selected device.
model = PilotNet().to(device)
criterion = nn.L1Loss().to(device)

# AdamW over all model parameters; every hyper-parameter is spelled out
# explicitly so the configuration is visible in the notebook.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0.01,
    amsgrad=False,
)

## Train

In [5]:
# Epoch count handed to the trainer. The captured output of this cell reports
# early stopping on epoch 27, so the run ended well before this limit.
N_EPOCHS = 100

# The trainer receives the model path prefix defined together with the model.
trainer = Trainer(model_name)
trainer.train(
    model,
    trainloader,
    validloader,
    optimizer,
    criterion,
    N_EPOCHS,
)

[34m[1mwandb[0m: Currently logged in as: [33mrometaidla[0m (use `wandb login --relogin` to force relogin)


  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

Saving best model.


  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

Saving best model.


  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

Saving best model.


  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

Saving best model.


  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

Saving best model.


  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/498 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

Early stopping, on epoch: 27.


## Save models

In [7]:
# Persist the weights the model holds at this point as the "last" checkpoint.
last_checkpoint_path = f"{model_name}/last.pt"
torch.save(model.state_dict(), last_checkpoint_path)

# Register both checkpoints with wandb. The final call is left as a bare
# expression so that its return value is displayed as the cell output.
wandb.save(last_checkpoint_path)
wandb.save(f"{model_name}/best.pt")

['/home/romet/Projects/nvidia-e2e/wandb/run-20210601_195321-x8kck25j/files/models/1-pilotnet-base/1cam-v2/best.pt']

In [8]:
# Take one validation batch to serve as the example input for the ONNX export.
# The builtin next() is used instead of the Python-2-style `.next()` method,
# which is not part of the iterator protocol and is not available on
# DataLoader iterators in current PyTorch releases.
data = next(iter(validloader))
inputs = data['image'].to(device)

# Load the best checkpoint through the trainer and export it to ONNX,
# using the batch above as the example input.
best_model = trainer.load_model(f"{model_name}/best.pt")
ONNX_FILE_PATH = f"{model_name}/best.onnx"
torch.onnx.export(best_model, inputs, ONNX_FILE_PATH)

# Register the exported file with wandb; left as the last expression so the
# returned value is shown as the cell output.
wandb.save(ONNX_FILE_PATH)

['/home/romet/Projects/nvidia-e2e/wandb/run-20210601_195321-x8kck25j/files/models/1-pilotnet-base/1cam-v2/best.onnx']