In [None]:
import bisect
from pathlib import Path
from typing import Tuple, Dict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.utils.data as data
import torch.nn as nn

from sklearn.preprocessing import StandardScaler

# from models import OneActuatorModel
# from data import OneActuatorDataset

In [None]:
PATH_TO_F1TENTH_GYM = Path('../f1tenth_gym')
BATCH_SIZE = 256
FULL_STATE = False  # if True, (v_x, v_y, omega) of both timesteps are appended to the features
SEED = 0

# Seed every RNG the notebook relies on: the random mirroring in the dataset (numpy),
# DataLoader shuffling and weight initialisation (torch). This makes
# Restart & Run All reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Peek at a single recorded race. NOTE: unpickling runs arbitrary code -- only load trusted files.
sample_race_file = Path('./data/together/2021-10-07_01_29_53.472821.pkl')
one_race = pd.read_pickle(sample_race_file)['data']

In [None]:
# First rows of the sample race loaded above.
one_race.head(n=5)

In [None]:
class OneActuatorDataset(data.Dataset):
    """(state transition -> actuator command) samples built from recorded races.

    Every ``*.pkl`` file in ``directory`` holds one race loaded with ``pd.read_pickle``.
    Consecutive rows (i, i + 1) of a race form one sample. The features describe the
    movement between the two rows; the target is the (speed_actuator, delta)
    command recorded in row i.
    """

    def __init__(self, directory: str, prob_flip: float = 0.5, full_state: bool = False):
        """
        Inputs:
            directory - The directory with .pkl files that are going to be unpacked
            prob_flip - Probability of mirroring a sample (negates the yaw change and
                        the steering angle and, with full_state, v_y and omega of
                        both rows)
            full_state - If True, (v_x, v_y, omega) of both rows are appended to the
                         features
        """
        super().__init__()

        self.directory = Path(directory)
        self.prob_flip = prob_flip
        self.full_state = full_state

        self.size = 0
        self.state_colnames = ['position', 'v_x', 'v_y', 'yaw', 'omega']
        self.actuator_colnames = ['speed_actuator', 'delta']
        self.num_state_coords = len(self.state_colnames)

        self.fetch_data()

    def fetch_data(self):
        """(Re)load every race in ``self.directory`` and rebuild the cumulative index."""
        self.data = []
        self.cumul_sizes = []
        # Reset here too, so calling fetch_data again does not double-count.
        self.size = 0
        # Sorted so that the global index -> race mapping does not depend on the
        # filesystem-specific order returned by glob.
        for filename in sorted(self.directory.glob('*.pkl')):
            # NOTE: unpickling can execute arbitrary code; only load trusted files.
            one_race = pd.read_pickle(filename)
            # n rows yield n - 1 transitions. Clamp at 0 so an empty race cannot
            # shrink the running total (cumul_sizes must stay non-decreasing).
            self.size += max(len(one_race) - 1, 0)
            self.cumul_sizes.append(self.size)
            one_race['v_x'] = one_race['velocity'].apply(lambda x: x[0])
            one_race['v_y'] = one_race['velocity'].apply(lambda x: x[1])
            one_race = one_race[self.state_colnames + self.actuator_colnames].values
            self.data.append(one_race)

    def __len__(self) -> int:
        """Total number of transitions over all races."""
        return self.size

    def __getitem__(self, idx) -> Tuple[Dict, Dict]:
        """Return the ``(features, targets)`` pair for global transition ``idx``.

        Features ('state_transition_features'):
            [Euclidean distance between the two positions, wrapped yaw_0 - yaw_1].
            With full_state, [v_x, v_y, omega] of row i and of row i + 1 are appended.
        Targets ('speed_and_delta'): [speed_actuator, delta] of row i.
        """
        # cumul_sizes[k] is the number of transitions in races 0..k. The race
        # containing idx is the first one whose cumulative size exceeds idx.
        which_race = bisect.bisect_right(self.cumul_sizes, idx)
        idx_shift = self.cumul_sizes[which_race - 1] if which_race > 0 else 0

        one_race = self.data[which_race]
        idx -= idx_shift

        ith_row = one_race[idx, :]
        i_plus_1th_row = one_race[idx + 1, :]
        position_0, v_x0, v_y0, yaw_0, omega_0 = ith_row[:self.num_state_coords]
        speed_actuator, delta = ith_row[self.num_state_coords:]
        position_1, v_x1, v_y1, yaw_1, omega_1 = i_plus_1th_row[:self.num_state_coords]

        position_diff = np.linalg.norm(position_0 - position_1)
        # Wrap the heading change into [-pi, pi]. A single wrap assumes
        # |yaw_0 - yaw_1| < 3*pi.
        yaw_diff = yaw_0 - yaw_1
        if yaw_diff > np.pi:
            yaw_diff -= (2 * np.pi)
        elif yaw_diff < -np.pi:
            yaw_diff += (2 * np.pi)

        if np.random.uniform() < self.prob_flip:
            # Mirror the sample: every left/right-signed quantity changes sign at
            # both timesteps. The distance travelled and v_x are unaffected.
            yaw_diff = -yaw_diff
            delta = -delta
            if self.full_state:
                v_y0 = -v_y0
                v_y1 = -v_y1
                omega_0 = -omega_0
                omega_1 = -omega_1

        state_transition_features = np.r_[position_diff, yaw_diff]
        if self.full_state:
            state_0 = np.r_[v_x0, v_y0, omega_0]
            state_1 = np.r_[v_x1, v_y1, omega_1]
            state_transition_features = np.r_[state_transition_features, state_0, state_1]

        return (
            {'state_transition_features': state_transition_features},
            {'speed_and_delta': np.r_[speed_actuator, delta]}
        )
        

In [None]:
# Training set: each sample is mirrored with probability 0.5 as augmentation.
train_dataset = OneActuatorDataset(
    directory='./data/train',
    prob_flip=0.5,
    full_state=FULL_STATE,
)

In [None]:
# Sanity check: display one (features, targets) pair from the training set.
train_dataset[100]

In [None]:
# Unshuffled, in-process pass over the training set, used to fit the scalers.
# `train_loader` is rebound later with shuffling for actual training.
train_loader = data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)

In [None]:
# One sample fixes the dict keys; every feature/target key gets its own scaler.
features, targets = train_dataset[100]
features_scalers = {name: StandardScaler() for name in features}
targets_scalers = {name: StandardScaler() for name in targets}

In [None]:
# One pass over the training data, updating each scaler's running mean/std batch by batch.
for features_batch, targets_batch in train_loader:
    for name, scaler in features_scalers.items():
        scaler.partial_fit(features_batch[name])
    for name, scaler in targets_scalers.items():
        scaler.partial_fit(targets_batch[name])

In [None]:
class OneActuatorModel(nn.Module):
    """Fully connected regressor from state-transition features to actuator commands.

    Architecture: ``num_layers`` hidden ``Linear -> SiLU`` blocks of width
    ``num_neurons`` (at least one block is always built), followed by a linear
    output layer of size ``output_size``.
    """

    def __init__(self, input_size: int, output_size: int, num_layers: int = 3, num_neurons: int = 30):
        super().__init__()

        # Fan-in/fan-out of each hidden block. The first block always exists,
        # hence the floor of one.
        widths = [input_size] + [num_neurons] * max(num_layers, 1)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.SiLU(inplace=True)))
        stack.append(nn.Linear(num_neurons, output_size))

        self.module = nn.Sequential(*stack)

    def forward(self, state_transition_features):
        """Return ``{'speed_and_delta': prediction}`` for a batch of feature rows."""
        prediction = self.module(state_transition_features)
        return {'speed_and_delta': prediction}

In [None]:
# Flattened widths of the feature / target vectors, taken from the sample `features` / `targets` dicts.
input_size = sum(map(len, features.values()))
output_size = sum(map(len, targets.values()))

In [None]:
# Default depth/width of the model (num_layers=3, num_neurons=30).
one_actuator_model = OneActuatorModel(input_size=input_size, output_size=output_size)

In [None]:
# Fall back to the CPU so the notebook also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
one_actuator_model.to(device)

In [None]:
# Training batches are reshuffled every epoch. Validation is iterated in dataset
# order (and never mirrored), so the metric and the last batch plotted later do
# not change from run to run.
# NOTE(review): with num_workers > 0, datasets defined in a notebook may fail to
# pickle under the "spawn" start method (macOS/Windows). Drop to 0 or move
# OneActuatorDataset to a module if that happens.
train_loader = data.DataLoader(train_dataset, BATCH_SIZE, shuffle=True, num_workers=1)

valid_dataset = OneActuatorDataset(directory='./data/valid', prob_flip=0.0, full_state=FULL_STATE)
valid_loader = data.DataLoader(valid_dataset, BATCH_SIZE, shuffle=False, num_workers=1)

In [None]:
num_epochs = 30

mse_loss = nn.MSELoss()

optimizer = torch.optim.Adam(one_actuator_model.parameters(), lr=1e-3)
train_mses_for_plot = []
valid_mses_for_plot = []


def scale_batch(batch, scalers):
    """Standardize each entry of ``batch`` with its fitted scaler; return float32 tensors on ``device``."""
    return {
        name: torch.from_numpy(scalers[name].transform(values)).float().to(device)
        for name, values in batch.items()
    }


for epoch in range(num_epochs):

    ############
    # Training #
    ############
    # Weighting each batch-mean loss by its batch size and dividing by the sample
    # count gives the exact per-sample mean, even when the last batch is smaller.
    train_total_mse = 0.0
    train_num_samples = 0
    one_actuator_model.train()
    for features_batch, targets_batch in train_loader:
        features_batch = scale_batch(features_batch, features_scalers)
        targets_batch = scale_batch(targets_batch, targets_scalers)

        speed_and_delta_preds = one_actuator_model(**features_batch)['speed_and_delta']
        speed_and_delta = targets_batch['speed_and_delta']

        loss = mse_loss(speed_and_delta_preds, speed_and_delta)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = speed_and_delta.shape[0]
        train_total_mse += loss.item() * batch_size
        train_num_samples += batch_size

    avg_train_mse = train_total_mse / max(train_num_samples, 1)
    print(f'Avg training MSE@{epoch}: {avg_train_mse:.3f}')
    train_mses_for_plot.append(avg_train_mse)

    ##############
    # Validation #
    ##############
    # The last validation batch (speed_and_delta / speed_and_delta_preds) is left
    # in scope on purpose: later cells plot it.
    valid_total_mse = 0.0
    valid_num_samples = 0
    one_actuator_model.eval()
    with torch.no_grad():
        for features_batch, targets_batch in valid_loader:
            features_batch = scale_batch(features_batch, features_scalers)
            targets_batch = scale_batch(targets_batch, targets_scalers)

            speed_and_delta_preds = one_actuator_model(**features_batch)['speed_and_delta']
            speed_and_delta = targets_batch['speed_and_delta']

            loss = mse_loss(speed_and_delta_preds, speed_and_delta)

            batch_size = speed_and_delta.shape[0]
            valid_total_mse += loss.item() * batch_size
            valid_num_samples += batch_size

    avg_valid_mse = valid_total_mse / max(valid_num_samples, 1)
    print(f'Avg validation MSE@{epoch}: {avg_valid_mse:.3f}\n')
    valid_mses_for_plot.append(avg_valid_mse)

In [None]:
# Per-epoch mean squared error on standardized targets.
fig, ax = plt.subplots()
ax.plot(train_mses_for_plot, label='train')
ax.plot(valid_mses_for_plot, label='valid')
ax.set(xlabel='epoch', ylabel='MSE (standardized targets)', title='Training curves')
ax.legend();

In [None]:
# Move the last validation batch (targets and predictions) to host numpy arrays.
# torch.as_tensor makes the cell idempotent: re-running it on already-converted
# arrays works instead of failing on `.cpu()`.
speed_and_delta = torch.as_tensor(speed_and_delta).cpu().numpy()
speed_and_delta_preds = torch.as_tensor(speed_and_delta_preds).cpu().numpy()

In [None]:
# Column 0 of the target is speed_actuator (standardized by its scaler).
fig, ax = plt.subplots()
ax.scatter(speed_and_delta[:, 0], speed_and_delta_preds[:, 0])
ax.set(xlabel='true speed_actuator (standardized)',
       ylabel='predicted speed_actuator (standardized)',
       title='Speed command: prediction vs. target')
ax.set_aspect('equal');

In [None]:
# Column 1 of the target is delta, the steering command (standardized by its scaler).
fig, ax = plt.subplots()
ax.scatter(speed_and_delta[:, 1], speed_and_delta_preds[:, 1])
ax.set(xlabel='true delta (standardized)',
       ylabel='predicted delta (standardized)',
       title='Steering command: prediction vs. target')
ax.set_aspect('equal');