In [4]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import ConcatDataset
import torchvision
from torchvision import transforms

from dataloading.nvidia import NvidiaTrainDataset, NvidiaSpringTrainDataset, NvidiaAutumnTrainDataset, NvidiaValidationDataset
from network import PilotNet
from trainer import Trainer

import wandb

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
# Directory holding the recorded drives; the dataset constructors in this
# notebook are pointed at this location.
root_path = Path("/media/romet/data2/datasets/rally-estonia/dataset")

# Spring training recordings, constructed with filter_turns=True.
# NOTE(review): the per-recording "<path>: <count>" lines in the cell output
# appear to be printed by the dataset constructor -- confirm in dataloading.nvidia.
spring_ds = NvidiaSpringTrainDataset(root_path, filter_turns=True)

# Shuffled mini-batch loader over the spring dataset.
spring_dl = torch.utils.data.DataLoader(
    spring_ds,
    batch_size=64,
    shuffle=True,
    num_workers=60,
    pin_memory=True,
    persistent_workers=True,
)

# Report how many frames the spring dataset exposes.
print(len(spring_ds.frames))

/media/romet/data2/datasets/rally-estonia/dataset/2021-05-20-12-36-10_e2e_sulaoja_20_30: 12025
/media/romet/data2/datasets/rally-estonia/dataset/2021-05-20-12-43-17_e2e_sulaoja_20_30: 6809
/media/romet/data2/datasets/rally-estonia/dataset/2021-05-20-12-51-29_e2e_sulaoja_20_30: 5393
/media/romet/data2/datasets/rally-estonia/dataset/2021-05-20-13-44-06_e2e_sulaoja_10_10: 3833
/media/romet/data2/datasets/rally-estonia/dataset/2021-05-20-13-51-21_e2e_sulaoja_10_10: 3798
/media/romet/data2/datasets/rally-estonia/dataset/2021-05-20-13-59-00_e2e_sulaoja_10_10: 687
/media/romet/data2/datasets/rally-estonia/dataset/2021-05-28-15-07-56_e2e_sulaoja_20_30: 15626
/media/romet/data2/datasets/rally-estonia/dataset/2021-05-28-15-17-19_e2e_sulaoja_20_30: 3218
/media/romet/data2/datasets/rally-estonia/dataset/2021-06-07-14-06-31_e2e_rec_ss6: 3003
/media/romet/data2/datasets/rally-estonia/dataset/2021-06-07-14-09-18_e2e_rec_ss6: 4551
/media/romet/data2/datasets/rally-estonia/dataset/2021-06-07-14-36-16_e

In [10]:
# Autumn training recordings, constructed with filter_turns=True from the same
# root directory used for the other datasets in this notebook.
autumn_ds = NvidiaAutumnTrainDataset(root_path, filter_turns=True)

# Shuffled mini-batch loader over the autumn dataset.
# The loader must wrap `autumn_ds`; wrapping the spring dataset here would make
# every consumer of `autumn_dl` silently iterate over spring data.
autumn_dl = torch.utils.data.DataLoader(
    autumn_ds,
    batch_size=64,
    shuffle=True,
    num_workers=60,
    pin_memory=True,
    persistent_workers=True,
)

# Report how many frames the autumn dataset exposes.
print(len(autumn_ds.frames))

/media/romet/data2/datasets/rally-estonia/dataset/2021-09-24-11-19-25_e2e_rec_ss10: 34760
/media/romet/data2/datasets/rally-estonia/dataset/2021-09-24-11-40-24_e2e_rec_ss10_2: 16223
/media/romet/data2/datasets/rally-estonia/dataset/2021-09-24-12-02-32_e2e_rec_ss10_3: 8142
/media/romet/data2/datasets/rally-estonia/dataset/2021-09-24-12-21-20_e2e_rec_ss10_backwards: 64975
/media/romet/data2/datasets/rally-estonia/dataset/2021-09-24-13-39-38_e2e_rec_ss11: 33255
/media/romet/data2/datasets/rally-estonia/dataset/2021-09-30-13-57-00_e2e_rec_ss14: 3287
/media/romet/data2/datasets/rally-estonia/dataset/2021-09-30-15-03-37_e2e_ss14_from_half_way: 21755
/media/romet/data2/datasets/rally-estonia/dataset/2021-09-30-15-20-14_e2e_ss14_backwards: 52762
/media/romet/data2/datasets/rally-estonia/dataset/2021-09-30-15-56-59_e2e_ss14_attempt_2: 66899
/media/romet/data2/datasets/rally-estonia/dataset/2021-10-07-11-05-13_e2e_rec_ss3: 54211
/media/romet/data2/datasets/rally-estonia/dataset/2021-10-07-11-44-