This notebook tests a ResNet+MLP model: the ResNet output is concatenated with a gravity vector (which the drone has via its IMU) and passed through an MLP to estimate the pose.

We try several hidden-layer configurations and compare two variants of each: networks that use the pre-trained ResNet weights as-is, and networks whose ResNet weights are also allowed to train on the given data.
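To make the configurations concrete, here is a minimal sketch of the layer shapes the MLP head would have. It rests on several assumptions that the notebook does not confirm: the ResNet feature vector has 512 entries (the pooled feature size of a ResNet-18), the gravity vector has 3 entries, and each hidden-size list ends with the output layer (all of them end in 7, matching the 7-component pose). `layer_shapes` and `parameter_count` are hypothetical helpers, not part of `models`.

```python
def layer_shapes(feature_dim, hidden_size, gravity_dim=3):
    """Return (in_features, out_features) for each linear layer of the MLP head."""
    # The MLP input is the ResNet feature vector concatenated with gravity.
    in_dim = feature_dim + gravity_dim
    shapes = []
    for out_dim in hidden_size:
        shapes.append((in_dim, out_dim))
        in_dim = out_dim  # the next layer consumes this layer's output
    return shapes


def parameter_count(shapes):
    """Count weights plus biases over all linear layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in shapes)


FEATURE_DIM = 512  # assumption: pooled ResNet-18 feature size

for hidden_size in [[256, 128, 7], [128, 32, 7], [128, 32, 8, 7]]:
    shapes = layer_shapes(FEATURE_DIM, hidden_size)
    print(hidden_size, shapes, parameter_count(shapes))
```

Under these assumptions the MLP head is small next to the ResNet backbone, so the choice between fixed and trainable ResNet weights changes the number of trained parameters far more than the choice of hidden sizes does.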

In [1]:
# Some useful settings for interactive work
%load_ext autoreload
%autoreload 2

In [2]:
import data_functions as df
import training_functions as tf
from models import *

In [3]:
# Initialize NN variables
hidden_sizes = [
    [256, 128, 7],
    [128, 32, 7],
    [128, 32, 8, 7],
]

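# Build one model per hidden-layer configuration for each value of the
# boolean flag (pre-trained vs. trainable ResNet weights).
mlp0 = [VisionPoseMLP(hidden_size, True) for hidden_size in hidden_sizes]
mlp1 = [VisionPoseMLP(hidden_size, False) for hidden_size in hidden_sizes]
mlps: list[VisionPoseMLP] = mlp0 + mlp1

# Number of training epochs
Neps = 300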

In [4]:
# Generate data loaders
train_loader, test_loader = df.get_data(0.8)

In [6]:
# Train the basic models
for idx, mlp in enumerate(mlps):
    mlp_name = f"basic{idx + 1:03d}"
    print("=============================================================")
    print(f"Training: {mlp_name}")
    tf.train_model(mlp, train_loader, mlp_name, useNeRF=False, Neps=Neps)
    print("-------------------------------------------------------------")
    tf.test_model(mlp, test_loader)

Training: basic001
Epoch: 100 | Loss: 0.02
Epoch: 200 | Loss: 0.01
Epoch: 300 | Loss: 0.01
Epoch: 300 | Loss: 0.01
-------------------------------------------------------------
-------------------------------------------------------------
Test Loss: 0.1822
Examples:
-------------------------------------------------------------
Output: [ 2.036  1.933  1.530 -0.229  0.590  0.745 -0.209]
Target: [ 1.361  1.642  1.458 -0.383  0.537  0.592 -0.462]
-------------------------------------------------------------
-------------------------------------------------------------
Output: [-0.977  0.465  1.551  0.159  0.010  0.752  0.640]
Target: [-2.725  0.833  1.729  0.527 -0.496 -0.464  0.511]
-------------------------------------------------------------
-------------------------------------------------------------
Output: [ 2.846 -0.610  1.014  0.540  0.148  0.247  0.791]
Target: [ 1.961 -0.425  0.964  0.481  0.309  0.443  0.690]
-------------------------------------------------------------
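The `Output` and `Target` rows above are easier to compare as a position error and a rotation error. The sketch below assumes that the first three components are a position and the last four a unit quaternion. The last four target components do have roughly unit norm, which is consistent with that reading, but the notebook does not state the layout. `pose_errors` is a hypothetical helper, not part of `training_functions`.

```python
import math


def pose_errors(output, target):
    """Return (position error, rotation error in degrees) for two 7-vectors.

    Assumed layout: [x, y, z, q0, q1, q2, q3]; the notebook does not confirm it.
    """
    # Euclidean distance between predicted and target positions.
    pos_err = math.dist(output[:3], target[:3])

    # Normalise both quaternions so their dot product is a cosine.
    q_out, q_tgt = output[3:], target[3:]
    n_out = math.sqrt(sum(c * c for c in q_out))
    n_tgt = math.sqrt(sum(c * c for c in q_tgt))
    dot = sum(a * b for a, b in zip(q_out, q_tgt)) / (n_out * n_tgt)

    # q and -q describe the same rotation, so use |dot|; clamp for acos.
    dot = min(1.0, abs(dot))
    ang_err_deg = math.degrees(2.0 * math.acos(dot))
    return pos_err, ang_err_deg


# Third example printed above for basic001.
output = [2.846, -0.610, 1.014, 0.540, 0.148, 0.247, 0.791]
target = [1.961, -0.425, 0.964, 0.481, 0.309, 0.443, 0.690]
pos_err, ang_err_deg = pose_errors(output, target)
print(f"position error: {pos_err:.3f}, rotation error: {ang_err_deg:.1f} deg")
```

The rotation error uses the absolute value of the quaternion dot product because a quaternion and its negation encode the same orientation; a plain component-wise loss would count them as far apart.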
Trainin