In [None]:
import math
import sys

import IPython
import IPython.display as ipd
import matplotlib.pylab as plt
import numpy as np
import pandas as pd

%reload_ext autoreload
%autoreload 2

%matplotlib inline
#%matplotlib notebook

from matplotlib import rcParams
rcParams["figure.max_open_warning"] = False
rcParams["font.family"] = 'DejaVu Sans'
rcParams["font.size"] = 12

In [None]:
from wall_analysis import parse_experiments
from crazyflie_description_py.experiments import WALL_ANGLE_DEG
import seaborn as sns
from simulation import get_df_theory
from plotting_tools import pcolorfast_custom
from plotting_tools import save_fig

# Plot styling per estimator key. "measured" is the key used by the
# frequency-slice analysis (section 1), "estimated" the one used by the
# distance-slice analysis (section 2); both need entries here.
linestyles = {"measured": "-", "estimated": "-", "theo": ":", "theo_corr": "-."}
colors = {"measured": "C0", "estimated": "C0", "theo": "black", "theo_corr": "black"}
keys = {
    "measured": "measured",
    "estimated": "estimated",
    "theo": "theo, vertical",
    "theo_corr": "theoretical",
}

MIN_Z_CM = 40  # [cm] samples at or below this altitude are treated as not flying

def plot_plositions(row, min_time=None, max_time=None, max_dist=None):
    """Plot the trajectory of one experiment in three panels.

    Panels: x-y path coloured by time, x/y/z over time, and yaw over time.

    Parameters
    ----------
    row : experiment record (e.g. a row of ``df_total``)
        Must provide ``positions`` (columns 0-2: x, y, z in metres, column 3:
        yaw, plotted in degrees), ``seconds`` (one timestamp per position)
        and ``appendix`` (used as the figure title).
    min_time, max_time : float, optional
        Bounds [s] of the plotted time window. A bound that is not given is
        replaced by the altitude filter ``z > MIN_Z_CM`` (flying samples only).
    max_dist : float, optional
        If given, limits the movement panel to ``[-max_dist, max_dist]`` cm.

    Returns
    -------
    fig, axs : the figure and its three axes.
    """
    positions_cm = row.positions[:, :3] * 100
    fig, axs = plt.subplots(1, 3)
    fig.set_size_inches(10, 3.3)
    fig.suptitle(row.appendix, y=1.0)

    seconds = np.asarray(row.seconds)
    flying = positions_cm[:, 2] > MIN_Z_CM

    # Each bound is applied on its own, so giving only min_time is honoured.
    # Plain `bool` dtype: the `np.bool` alias no longer exists in NumPy >= 1.24.
    mask_time = np.ones(seconds.shape, dtype=bool)
    mask_time &= (seconds > min_time) if min_time is not None else flying
    mask_time &= (seconds < max_time) if max_time is not None else flying

    time = seconds[mask_time]

    # x-y path, coloured by time
    sns.scatterplot(x=positions_cm[mask_time, 0], y=positions_cm[mask_time, 1],
                    hue=time, ax=axs[0], linewidth=0,
                    palette='inferno')
    axs[0].set_xlabel('x [cm]')
    axs[0].set_ylabel('y [cm]')
    axs[0].axis('equal')
    axs[0].legend(loc='lower right', title='time [s]')

    # coordinates over time
    axs[1].plot(time, positions_cm[mask_time, 0], label='x')
    axs[1].plot(time, positions_cm[mask_time, 1], label='y')
    axs[1].plot(time, positions_cm[mask_time, 2], label='z')
    axs[1].set_xlabel('time [s]')
    axs[1].set_ylabel('movement [cm]')
    if max_dist is not None:
        axs[1].set_ylim(-max_dist, max_dist)
    axs[1].legend(loc='lower right')

    # yaw over time
    axs[2].plot(time, row.positions[mask_time, 3], label='yaw')
    axs[2].set_ylabel('yaw [deg]')
    axs[2].set_xlabel('time [s]')
    axs[2].set_ylim(-20, 20)
    axs[2].legend(loc='lower right')
    plt.tight_layout()
    return fig, axs

def plot_audio(row, mic_idx=0):
    """Plot the magnitude spectrogram of one microphone for one experiment.

    Zero entries are shown as missing (NaN). Where consecutive timestamps are
    more than 1 s apart, an all-NaN column is inserted one average frame
    spacing after the last sample before the gap, so the gap is not drawn as
    one stretched column.

    Parameters
    ----------
    row : experiment record
        Must provide ``spectrogram`` (indexed ``[frequency, mic, time]``),
        ``freqs`` (frequency axis [Hz]), ``seconds`` (timestamps [s], array)
        and ``appendix``.
    mic_idx : int
        Microphone to plot.

    Returns
    -------
    fig, ax
    """
    all_frequencies = row.freqs
    # Work on a copy: masking zeros must not modify the array stored in `row`
    # (and therefore in the DataFrame the row came from).
    spec = row.spectrogram[:, mic_idx, :].copy()
    spec[spec == 0] = np.nan

    fig, ax = plt.subplots()
    fig.set_size_inches(10, 5)
    ax.set_title(f'spectrogram of mic{mic_idx}, appendix {row.appendix}')

    # Mark recording gaps (spacing > max_diff seconds) with a NaN column.
    max_diff = 1
    seconds = row.seconds
    diff = np.diff(seconds)
    gap_idx = np.where(diff > max_diff)[0]
    if gap_idx.size > 0:
        diff_average = np.mean(diff[diff < max_diff])
        # With several indices, np.insert places each value before the given
        # *original* index, i.e. right after each sample preceding a gap.
        seconds = np.insert(seconds, gap_idx + 1, seconds[gap_idx] + diff_average)
        spec = np.insert(spec, gap_idx + 1, np.nan, axis=1)

    pcolorfast_custom(ax, seconds, all_frequencies, np.abs(spec))

    xticks = np.arange(0, row.seconds.max(), step=5)
    ax.set_xticks(xticks); ax.set_xticklabels(xticks)
    # y ticks every 1 kHz, starting at the first full kHz above the minimum
    yticks = np.arange((np.round(row.freqs.min()//1000)+1)*1000,
                        row.freqs.max(), step=1000)
    ax.set_yticks(yticks); ax.set_yticklabels(yticks)
    ax.set_ylabel('frequency [Hz]')
    ax.set_xlabel('seconds [s]')
    return fig, ax

def plot_df(distance_range, freq_range=(2000, 6000), azimuth_deg=WALL_ANGLE_DEG, mic_idx=0):
    """Plot log10 of the theoretical distance-frequency matrix of one mic in grey.

    Parameters
    ----------
    distance_range : (float, float)
        Min and max distance [cm] of the grid (100 points).
    freq_range : (float, float)
        Min and max frequency [Hz] of the grid (200 points); y ticks every 1 kHz.
    azimuth_deg : float
        Angle passed on to ``get_df_theory``.
    mic_idx : int
        Microphone whose theoretical response is shown.

    Returns
    -------
    fig_df, ax_df
    """
    d_min, d_max = distance_range[0], distance_range[1]
    distances_grid = np.linspace(d_min, d_max, 100)
    freqs_theo = np.linspace(freq_range[0], freq_range[1], 200)
    df_matrix_theo = get_df_theory(freqs_theo, distances_grid, azimuth_deg=azimuth_deg,
                                   chosen_mics=[mic_idx])
    fig_df, ax_df = plt.subplots()
    fig_df.set_size_inches(5, 3)
    # First tick at the start of the range, then every multiple of 10 cm after it.
    # set_xticks widens the view to include all ticks, so the first tick must
    # follow the range instead of being a fixed value.
    first_decade = (int(d_min) // 10 + 1) * 10
    xticks = [d_min] + list(np.arange(first_decade, d_max + 1, step=10))
    yticks = np.arange(freq_range[0], freq_range[1] + 1, step=1000)
    pcolorfast_custom(
        ax_df, distances_grid, freqs_theo, np.log10(df_matrix_theo[0]), cmap='Greys',
        alpha=0.5,
    )
    ax_df.set_xticks(xticks); ax_df.set_xticklabels(xticks)
    ax_df.set_yticks(yticks); ax_df.set_yticklabels(yticks)
    ax_df.set_xlabel('distance [cm]')
    ax_df.set_ylabel('frequency [Hz]')
    return fig_df, ax_df

def change_order(axs_all, mean_distances, xlabel="distance", unit="cm", ylabel="probability [-]", title=True):
    """Rearrange a row of subplots so that they appear sorted by distance.

    The i-th axes belongs to ``mean_distances[i]``. Axes are moved (by handing
    out the existing subplot positions in sorted order) so the smallest
    distance lands in the first slot. Only the first slot keeps its y-axis and
    gets ``ylabel``; the middle slot gets the x label; the last slot gets the
    legend. Axes without a matching entry in ``mean_distances`` (the callers
    may create more subplots than slices are found) are hidden.

    Parameters
    ----------
    axs_all : sequence or 1-D array of Axes
    mean_distances : sequence of float, one per used axes
    xlabel, unit : str
        X label is ``f"{xlabel} [{unit}]"``; titles are ``f"{d:.0f}{unit}"``.
    ylabel : str
    title : bool
        Whether to title each axes with its distance.

    Returns
    -------
    numpy array of the axes, in input order.
    """
    axs_all = np.asarray(axs_all, dtype=object)
    n_used = len(mean_distances)

    # Hide surplus subplots that never received data.
    for ax in axs_all[n_used:]:
        ax.set_visible(False)
    if n_used == 0:
        return axs_all

    sorted_idx = np.argsort(mean_distances)
    positions = [ax.get_position() for ax in axs_all]
    for slot, idx in enumerate(sorted_idx):
        axs_all[idx].set_position(positions[slot])
        if title:
            axs_all[idx].set_title(f"{mean_distances[idx]:.0f}{unit}")
    # only the left-most (closest) panel keeps its y-axis
    for idx in sorted_idx[1:]:
        axs_all[idx].get_yaxis().set_visible(False)
    axs_all[sorted_idx[len(sorted_idx) // 2]].set_xlabel(f"{xlabel} [{unit}]")
    axs_all[sorted_idx[0]].set_ylabel(ylabel)
    axs_all[sorted_idx[-1]].legend(loc="upper left", bbox_to_anchor=[1.0, 1.0])
    return axs_all

# 1. Frequency slice

In [None]:
# Experiment to analyse; the commented names are other recorded runs.
#exp_name = '2020_12_18_stepper'; appendix = ""; distance = 51
#exp_name = '2020_12_18_flying'; appendix="_new"; distance = 0
#exp_name = '2021_03_01_flying'
#exp_name = '2021_04_30_hover'
exp_name = '2021_05_04_flying'
#exp_name = '2021_07_14_flying_hover'
fname = f'../experiments/{exp_name}/all_data.pkl'

# Load the cached parse of the experiment; if it cannot be read, offer to
# re-parse the raw data. Unpickling can execute code: only load trusted files.
try:
    df_total = pd.read_pickle(fname)
    print('read', fname)
except Exception as e:  # unlike a bare except, lets KeyboardInterrupt through
    print(f'could not read {fname}: {e}')
    answer = input('Run wall_analysis.py to parse experiments? (y/[n])') or 'n'
    if answer == 'y':
        df_total = parse_experiments(exp_name)
        pd.to_pickle(df_total, fname)
        print('saved', fname)
    else:
        # Fail here instead of with a NameError (or on stale data) in later cells.
        raise RuntimeError(f'no data loaded for {exp_name}') from e

In [None]:
df_total = df_total.sort_values(by='appendix')  # stable, appendix-ordered figures

## 1.1 positions analysis

In [None]:
from geometry import Context
# 2-D layout of the Crazyflie setup
context = Context.get_crazyflie_setup(dim=2)
fig, ax = plt.subplots(figsize=(3, 3))
context.plot(ax=ax)

In [None]:
min_time = None  # e.g. 4 [s]; None: use the altitude filter instead
max_time = None  # e.g. 14 [s]; None: use the altitude filter instead
max_dist = None

#starting_distance = 65.63 # 42+29.7-6.07
starting_distance = 100  # [cm] subtracted from y so that the wall line is y = 0

fig_total, ax_total = plt.subplots()
fig_total.set_size_inches(3, 3)
for i, row in df_total.iterrows():
    fig, axs = plot_plositions(row, min_time, max_time, max_dist)

    x = row.positions[:, 0] * 100
    y = row.positions[:, 1] * 100 - starting_distance
    z = row.positions[:, 2] * 100

    # One joint mask keeps x and y aligned sample by sample: keep samples
    # with valid x and y that were recorded while flying (NaN z compares False).
    flying = ~np.isnan(x) & ~np.isnan(y) & (z > MIN_Z_CM)
    x, y = x[flying], y[flying]

    ax_total.plot(x, y, marker='o', label=row.appendix.replace('_', ''))
    #save_fig(fig, f'plots/experiments/{exp_name}{row.appendix}_movement', extension='.png')
ax_total.axis('equal')
ax_total.set_xlabel('x [cm]')
ax_total.set_ylabel('y [cm]')
ax_total.axhline(0, color='k', label='wall')
ax_total.legend(bbox_to_anchor=[1.0, 1.0], loc='upper left')
#save_fig(fig_total, f'plots/experiments/{exp_name}_pos.png')

## 1.2 audio analysis

In [None]:
from frequency_analysis import add_spectrogram

# Add empty spectrogram/freqs columns, then let add_spectrogram fill them per row.
df_total = df_total.assign(spectrogram=None, freqs=None).apply(add_spectrogram, axis=1)
print(df_total.columns)

mic_idx = 0

# One spectrogram figure per experiment.
for i_col, row in df_total.iterrows():
    fig, ax = plot_audio(row, mic_idx=mic_idx)
    #save_fig(fig, f'plots/experiments/{exp_name}{row.appendix}_spec')

## 1.3 algorithm performance, snr < 5

In [None]:
from calibration import get_calibration_function_median, get_calibration_function_dict
from inference import Inference

fig, ax = plt.subplots(figsize=(10, 5))
# Calibration curve from experiment 2021_04_30_stepper, drawn on ax
# (alternative option: fit_one_gain=True).
calib_function, calib_freq = get_calibration_function_median(
    "2021_04_30_stepper", "audio_deck", ax=ax, snr=3, motors=0
)
inf_machine = Inference()
inf_machine.add_calibration_function(calib_function)

In [None]:
from dataset_parameters import kwargs_datasets
from crazyflie_description_py.experiments import WALL_ANGLE_DEG

# dataset-specific keyword arguments (not used further in this cell)
kwargs = kwargs_datasets[exp_name]["audio_deck"]
azimuth_deg = WALL_ANGLE_DEG

# search grid [cm] and frequency band [Hz]; freq_range is not passed to plot_df
distance_range, freq_range = [7, 50], [3000, 5000]

fig_df, ax_df = plot_df(distance_range)  # theoretical background
inf_machine.add_geometry(distance_range, azimuth_deg)

In [None]:
# spec_masked, freqs_masked = data_collector.fill_from_row()
from copy import deepcopy
from data_collector import DataCollector
from estimators import DistanceEstimator, get_estimate
from inference import eps_normalize
from simulation import get_freq_slice_theory
from itertools import cycle

plot_raw = False  # plot per-slice raw data, inference results and calibration curves
eps = 1e-5  # for plotting only (passed to eps_normalize)
max_plot_distance = 100  # slices with a larger mean distance are not plotted
algorithm = "bayes"  # alternative: "cost"
normalize = True  # normalize probas before combining
method = "sum"  # method used to combine

nominal_ds = [60, 40, 20]  # nominal distance per sweep, cycled in this order
chosen_appendix = "_22"  # experiment analysed below

# Select the experiment FIRST: the calibration grid must be built from this
# row's frequencies, not from whatever `row` an earlier cell's loop left behind.
matching_rows = df_total.loc[df_total.appendix == chosen_appendix]
if matching_rows.empty:
    raise ValueError(f"no experiment with appendix {chosen_appendix!r} in {exp_name}")
row = matching_rows.iloc[0]

freqs_calib = np.linspace(
    np.min(row.frequencies_matrix), np.max(row.frequencies_matrix), 50
)
f_calib_all = calib_function(freqs_calib)

# Run the frequency-slice analysis once per microphone subset passed to
# get_distance_distribution (only `None` is active; other subsets are in the comment).
for chosen_mics in [None]:  # , [0], [2], [1, 3], [0, 1], [0, 3], [0, 1, 3]]:
    # Nominal distance subtracted from each successive sweep, taken from
    # nominal_ds in order and repeated.
    nominal_distances = cycle(nominal_ds)

    plot_idx = 0  # next free subplot in fig_all / fig_res
    mean_distances = []  # mean distance of every plotted slice, used to sort subplots later

    data_collector = DataCollector(exp_name=exp_name)

    # Indices of the samples recorded above MIN_Z_CM, i.e. while flying.
    flying_time_indices = np.where(row.positions[:, 2] * 1e2 > MIN_Z_CM)[0]

    count = 0  # samples added since the last processed sweep
    sweep_complete = False

    fig_df, ax_df = plot_df(distance_range)
    # NOTE(review): one subplot per flying sample, but only one per processed
    # sweep (plot_idx) is ever filled. With exactly one flying sample,
    # plt.subplots returns a single Axes and axs_all[plot_idx] fails.
    fig_all, axs_all = plt.subplots(1, len(flying_time_indices))
    fig_res, axs_res = plt.subplots(1, len(flying_time_indices), sharex=True)
    fig_all.set_size_inches(15, 3)
    fig_res.set_size_inches(15, 2)

    for i in flying_time_indices:
        signals_f = row.stft[i]
        frequencies = row.frequencies_matrix[i]

        # A sweep was completed on the previous iteration: run inference on the
        # collected frequency slice before the current sample is added (below).
        if sweep_complete:
            print("treating new frequency slice after", count)
            count = 0

            nominal_distance = next(nominal_distances)
            (
                f_slice,
                freqs,
                stds,
                distances,
            ) = data_collector.get_current_frequency_slice(verbose=False)

            # NOTE(review): starting_distance is the notebook global (100 in the
            # positions cell); re-running this cell after section 2 would use 65.63.
            rel_distances = starting_distance - distances - nominal_distance
            inf_machine.add_data(
                deepcopy(f_slice), freqs, stds, deepcopy(rel_distances)
            )
            inf_machine.calibrate()

            if plot_raw:
                fig, axs = plt.subplots(3, row.stft.shape[1])  # , sharey='row')
                fig.set_size_inches(15, 10)
                fig.suptitle(f"up to {data_collector.latest_fslice_time:.1f}s", y=0.9)
            # fig.set_suptitle(f"experiment {row.appendix}")

            # raw data, for plotting only
            # (restricted to the entries selected by inf_machine.valid_idx)
            freqs = inf_machine.values[inf_machine.valid_idx]
            f_slice = f_slice[:, inf_machine.valid_idx]
            distances = distances[inf_machine.valid_idx]

            distance_corr = starting_distance - distances
            mean_distance = np.nanmean(distance_corr)
            rel_distances = distance_corr - nominal_distance  # relative movement

            if mean_distance < max_plot_distance:
                ax_df.scatter(distance_corr, freqs, color=f"C{plot_idx}")
                ax_df.axvline(
                    mean_distance, color="black", ls=":", label="mean distance"
                )

            # New estimators for every slice.
            distance_estimators = {
                "measured": DistanceEstimator(),
                "theo": DistanceEstimator(),
                "theo_corr": DistanceEstimator(),
            }

            if plot_raw:
                for i_mic in range(f_slice.shape[0]):
                    axs[0, i_mic].plot(freqs, f_slice[i_mic], label="measured")
                    axs[0, i_mic].legend(loc="upper right")
                    axs[0, i_mic].set_title(f"mic{i_mic}")

            # treat measured data
            for i_mic in range(f_slice.shape[0]):
                dists, proba, diff = inf_machine.do_inference(
                    mic_idx=i_mic, algorithm=algorithm, normalize=normalize
                )
                # diff is scaled by 1e-2 (presumably cm -> m) for the estimator
                distance_estimators["measured"].add_distribution(
                    diff * 1e-2, proba, i_mic
                )
                if plot_raw:
                    inf_machine.plot(
                        i_mic,
                        ax=axs[1, i_mic],
                        label="measured",
                        color=f"C{plot_idx}",
                        ls=linestyles["measured"],
                        standardize=True,
                    )
                    axs[2, i_mic].plot(
                        dists,
                        eps_normalize(proba, eps),
                        label=keys["measured"],
                        ls=linestyles["measured"],
                        color=f"C{plot_idx}",
                    )

            # treat theoretical data
            # "theo": theoretical slice at the single mean distance of this slice
            f_theo = get_freq_slice_theory(
                freqs, distance_cm=mean_distance, azimuth_deg=azimuth_deg
            ).T  # n_mics x n_freqs
            inf_machine.add_data(f_theo, freqs)
            for i_mic in range(f_theo.shape[0]):
                dists_theo, proba_theo, diff_theo = inf_machine.do_inference(
                    mic_idx=i_mic,
                    algorithm=algorithm,
                    normalize=normalize,
                    calibrate=False,
                )
                distance_estimators["theo"].add_distribution(
                    diff_theo * 1e-2, proba_theo, i_mic
                )
                if plot_raw:
                    inf_machine.plot(
                        i_mic,
                        ax=axs[1, i_mic],
                        label=keys["theo"],
                        color=colors["theo"],
                        ls=linestyles["theo"],
                        standardize=True,
                    )
                    axs[2, i_mic].plot(
                        dists_theo,
                        eps_normalize(proba_theo, eps=eps),
                        label=keys["theo"],
                        ls=linestyles["theo"],
                        color=colors["theo"],
                    )

            # "theo_corr": theoretical slice evaluated at the per-sample distances
            f_theo_corr = get_freq_slice_theory(
                freqs, distance_cm=distance_corr, azimuth_deg=azimuth_deg
            ).T
            inf_machine.add_data(f_theo_corr, freqs, distances=rel_distances)
            for i_mic in range(f_theo_corr.shape[0]):
                (
                    dists_theo_corr,
                    proba_theo_corr,
                    diff_theo_corr,
                ) = inf_machine.do_inference(
                    mic_idx=i_mic,
                    algorithm=algorithm,
                    normalize=normalize,
                    calibrate=False,
                )
                # NOTE(review): unlike "measured" and "theo", the mean is subtracted
                # here, which can make the values negative; confirm this is intended.
                distance_estimators["theo_corr"].add_distribution(
                    diff_theo_corr * 1e-2,
                    proba_theo_corr - np.mean(proba_theo_corr),
                    i_mic,
                )

                if plot_raw:
                    inf_machine.plot(
                        i_mic,
                        ax=axs[1, i_mic],
                        label=keys["theo_corr"],
                        color=colors["theo_corr"],
                        ls=linestyles["theo_corr"],
                        standardize=True,
                    )
                    axs[2, i_mic].plot(
                        dists_theo_corr,
                        eps_normalize(proba_theo_corr, eps),
                        label=keys["theo_corr"],
                        ls=linestyles["theo_corr"],
                        color=colors["theo_corr"],
                    )
                    axs[2, i_mic].set_yscale("log")

            # Combine the per-mic distributions of each estimator ("theo" is skipped).
            # NOTE(review): keys[key] / linestyles[key] require a "measured" entry
            # in the module-level `keys` and `linestyles` dicts.
            for key, distance_estimator in distance_estimators.items():
                if key == "theo":
                    continue
                (
                    distance_total,
                    proba_total,
                ) = distance_estimator.get_distance_distribution(
                    method=method, chosen_mics=chosen_mics, azimuth_deg=azimuth_deg
                )

                if mean_distance < max_plot_distance:
                    axs_all[plot_idx].plot(
                        distance_total * 1e2,
                        eps_normalize(proba_total, eps),
                        label=keys[key],
                        ls=linestyles[key],
                        color=f"C{plot_idx}",
                    )

                    d = get_estimate(distance_total * 1e2, proba_total)
                    axs_res[plot_idx].axvline(
                        d, label=keys[key], ls=linestyles[key], color=f"C{plot_idx}",
                    )

            if mean_distance < max_plot_distance:
                axs_all[plot_idx].axvline(
                    mean_distance, color="k", ls=":", label="mean distance"
                )
                axs_all[plot_idx].set_yscale("log")

                axs_res[plot_idx].axvline(
                    mean_distance, color="k", ls=":", label="mean distance"
                )
                axs_res[plot_idx].scatter(distance_corr, freqs, color=f"C{plot_idx}")
                axs_res[plot_idx].set_ylim(3000, 5000)

                mean_distances.append(mean_distance)
                plot_idx += 1

            if plot_raw:
                # NOTE(review): i_mic is left over from the last mic loop, so only
                # the last microphone's panels get a legend.
                axs[1, i_mic].legend(loc="upper right")
                axs[2, i_mic].legend(loc="upper right")
                for i_mic in range(f_calib_all.shape[0]):
                    axs[0, i_mic].plot(
                        freqs_calib, f_calib_all[i_mic], label=f"calib", ls=":"
                    )
        # Decide whether the sweep is complete; it is processed at the start of
        # the NEXT iteration. NOTE(review): a sweep completed at the last index is
        # therefore never processed, and this last-index test only fires if the
        # final stft sample is itself a flying one.
        if i == row.stft.shape[0] - 1:
            sweep_complete = True
            print("done")
        elif i == 0:
            sweep_complete = False
        else:
            # bin_selection < 5: ask the collector whether a full sweep is ready;
            # otherwise every sample is treated as a complete slice.
            if row.bin_selection < 5:
                sweep_complete = data_collector.next_fslice_ready(
                    signals_f, frequencies, verbose=False
                )
            else:
                sweep_complete = True

        # add the current sample to the collector
        mode = "maximum" if row.bin_selection < 5 else "all"
        data_collector.fill_from_signal(
            signals_f,
            frequencies,
            distance_cm=row.positions[i, 1] * 1e2,
            time=row.seconds[i],
            mode=mode,
        )
        count += 1


# Sort the subplots by mean distance, then write the three figures to disk.
print(mean_distances)
axs_all = change_order(axs_all, mean_distances)
axs_res = change_order(
    axs_res, mean_distances, ylabel="frequency [Hz]", title=False
)

save_fig(fig_df, f"plots/experiments/{exp_name}_df.png")
save_fig(fig_all, f"plots/experiments/{exp_name}_mics{chosen_mics}_all.png")
save_fig(fig_res, f"plots/experiments/{exp_name}_mics{chosen_mics}_res.png")

# 2. Distance slice

In [None]:
# Experiment to analyse; the commented names are other recorded runs.
#exp_name = '2020_12_18_stepper'; appendix = ""; distance = 51
#exp_name = '2020_12_18_flying'; appendix="_new"; distance = 0
#exp_name = '2021_03_01_flying'
#exp_name = '2021_04_30_hover'
exp_name = '2021_05_04_linear'
fname = f'../experiments/{exp_name}/all_data.pkl'

# Load the cached parse of the experiment; if it cannot be read, offer to
# re-parse the raw data. Unpickling can execute code: only load trusted files.
try:
    df_total = pd.read_pickle(fname)
    print('read', fname)
except Exception as e:
    print(f'could not read {fname}: {e}')
    answer = input('Run wall_analysis.py to parse experiments? (y/[n])') or 'n'
    if answer == 'y':
        df_total = parse_experiments(exp_name)
        pd.to_pickle(df_total, fname)
        print('saved', fname)
    else:
        # Fail loudly: otherwise the df_total of the previous section would be
        # silently reused by the cells below.
        raise RuntimeError(f'no data loaded for {exp_name}') from e

## 2.1 positions analysis

In [None]:
starting_distance = 65.63  # [cm], 42 + 29.7 - 6.07

# (x [cm], yaw [deg]) at take-off of every run; all runs start at
# y = -starting_distance and z = 0.
_start_x_yaw = {
    '_1': (20, 45),
    '_2': (-20, -45),
    '_3': (0, 0),
    '_4': (-10, -30),
    '_5': (10, 30),
    '_fast1': (-20, -45),
    '_fast2': (20, 45),
    '_fast3': (0, 0),
    '_fast4': (20, 30),
    '_fast5': (-20, -30),
}
# appendix -> [x, y, z, yaw]
starting_positions = {
    appendix: [x, -starting_distance, 0, yaw]
    for appendix, (x, yaw) in _start_x_yaw.items()
}

def get_average_angle(positions_rot):
    """Median heading [deg] of a trajectory, measured from its first sample.

    For every sample, the direction of its displacement from the first sample
    is computed with ``arctan2`` (range (-180, 180]) and the median is
    returned. Samples that coincide with the first one (including the first
    sample itself) have no direction (arctan2(0, 0) = 0 would bias the median
    towards 0) and are ignored.

    Note: this is a plain median, so headings that straddle +-180 deg are not
    averaged circularly.

    Parameters
    ----------
    positions_rot : array of shape (n, >=2), columns 0 and 1 are x and y.

    Returns
    -------
    float : median heading in degrees, or nan if no sample moved away from
    the start (or the input is empty).
    """
    positions_rot = np.asarray(positions_rot, dtype=float)
    if len(positions_rot) == 0:
        return np.nan
    dx = positions_rot[:, 0] - positions_rot[0, 0]
    dy = positions_rot[:, 1] - positions_rot[0, 1]
    moved = (dx != 0) | (dy != 0)
    if not np.any(moved):
        return np.nan
    angles = np.degrees(np.arctan2(dy[moved], dx[moved]))
    return np.median(angles)

def get_gt_angle(row):
    """Ground-truth approach angle [deg] of a run: its starting yaw plus 90."""
    return 90 + starting_positions[row.appendix][3]

def get_corrected_positions(appendix, positions, starting_pose=None, min_z_cm=35):
    """Express logged positions in the frame in which ``starting_positions`` are given.

    Each sample's (x, y, z) [m] is rotated about z by the starting yaw plus
    the sample's yaw, scaled to cm and shifted by the starting position.
    Samples containing NaN, or at or below ``min_z_cm``, are dropped.

    Parameters
    ----------
    appendix : str
        Key into ``starting_positions``; ignored when ``starting_pose`` is given.
    positions : array of shape (n, 4)
        Columns x, y, z [m] and yaw [deg].
    starting_pose : sequence of 4 numbers, optional
        (x [cm], y [cm], z [cm], yaw [deg]); defaults to
        ``starting_positions[appendix]``.
    min_z_cm : float
        Altitude threshold [cm] (default 35).

    Returns
    -------
    ndarray of shape (m, 4): x, y, z [cm] and total yaw [deg] of the kept samples.
    """
    # Local import: the module-level alias `R` is only imported in a later cell.
    from scipy.spatial.transform import Rotation

    if starting_pose is None:
        starting_pose = starting_positions[appendix]
    positions = np.asarray(positions, dtype=float)
    if positions.shape[0] == 0:
        return np.empty((0, 4))

    # One rotation per sample, applied to all samples at once.
    total_yaw = starting_pose[3] + positions[:, 3]
    rot = Rotation.from_euler('z', total_yaw, degrees=True)
    xyz = np.asarray(starting_pose[:3], dtype=float) + rot.apply(positions[:, :3]) * 1e2
    positions_rot = np.column_stack([xyz, total_yaw])

    valid = np.all(~np.isnan(positions_rot), axis=1) & (positions_rot[:, 2] > min_z_cm)
    return positions_rot[valid, :]

def plot_corrected_positions(row, max_idx=30, ax=None, **kwargs):
    """Plot the first ``max_idx`` corrected x-y positions of one run.

    The legend label shows the appendix, the median heading of the corrected
    trajectory and the ground-truth approach angle. Extra keyword arguments
    go to ``ax.plot``. A new figure is created when ``ax`` is None.

    Returns the full array of corrected positions (not truncated).
    """
    if ax is None:
        _, ax = plt.subplots()

    positions_rot = get_corrected_positions(row.appendix, row.positions)
    average_angle, gt_angle = get_average_angle(positions_rot), get_gt_angle(row)

    head = positions_rot[:max_idx]
    ax.plot(head[:, 0], head[:, 1],
            label=f'experiment{row.appendix}, {average_angle:.0f} {gt_angle:.0f}', **kwargs)
    ax.axis('equal')
    return positions_rot

In [None]:
from scipy.spatial.transform import Rotation as R

max_idx = 40  # number of corrected positions shown per run
fig_df, ax_df = plot_df(distance_range=[7, 70])

fig, ax = plt.subplots()
fig.set_size_inches(3, 3)
for i, row in df_total.iterrows():
    fig, axs = plot_plositions(row, min_time=None, max_time=None, max_dist=None)
    # plot_corrected_positions already returns the corrected positions
    positions_corr = plot_corrected_positions(row, ax=ax, max_idx=max_idx)

    # -y of the plotted samples (distance in front of y = 0), marked at 3 kHz
    ds = -positions_corr[:max_idx, 1]
    ax_df.scatter(ds, np.full(ds.shape, 3000))
ax.legend(bbox_to_anchor=[1.0, 1.0], loc='upper left')

## 2.2 audio analysis

In [None]:
from frequency_analysis import add_spectrogram
from plotting_tools import pcolorfast_custom

# Same columns as in section 1, filled per row by add_spectrogram.
df_total = df_total.assign(spectrogram=None, freqs=None)
df_total = df_total.apply(add_spectrogram, axis=1)

mic_idx = 0
# Per-frame spectra of every run (debugging aid, off by default).
plot_frame_spectra = False
if plot_frame_spectra:
    for i_col, row in df_total.iterrows():
        fig, ax = plt.subplots()
        ax.set_title(f'mic{mic_idx} spectrum per frame, appendix {row.appendix}')
        for time_idx in range(row.frequencies_matrix.shape[0]):
            freqs = row.frequencies_matrix[time_idx, :]
            response = np.abs(row.spectrogram[:, mic_idx, time_idx])
            ax.plot(freqs, response)
        ax.set_xlabel('frequency [Hz]')
        ax.set_ylabel('magnitude')

## 2.3 algorithm analysis

In [None]:
from inference import get_approach_angle_fft
from data_collector import DataCollector
from estimators import AngleEstimator, get_estimate
from copy import deepcopy

plot_raw = False  # True: extra figure per distance slice with raw data and per-mic probabilities

mics = [0, 1, 3]
n_mics = len(mics)
appendices = [f"_{i}" for i in range(1, 6)]  # further candidates: "_fast2", "_fast4"

n_rows, n_cols = 3, len(appendices)

start = 0


def _strip_ticks_and_label(axes, xlabel, ylabel):
    """Remove all ticks of a subplot grid; label bottom-middle x and left-middle y."""
    for grid_ax in axes.flat:
        grid_ax.set_xticks([])
        grid_ax.set_yticks([])
    axes[-1, n_cols // 2].set_xlabel(xlabel)
    axes[n_rows // 2, 0].set_ylabel(ylabel)


# Rows: successive distance slices, columns: runs.
fig_all, axs_all = plt.subplots(n_rows, n_cols, sharey=True)
_strip_ticks_and_label(axs_all, 'angle [$^\\circ$]', 'probability')
fig_all.set_size_inches(1.5 * n_cols, 1.5 * n_rows)

fig_res, axs_res = plt.subplots(n_rows, n_cols, sharey=True)
_strip_ticks_and_label(axs_res, 'x [cm]', 'y [cm]')
fig_res.set_size_inches(1.5 * n_cols, 1.5 * n_rows)

mean_angles = []
skip_first = 5  # number of initial samples ignored in each run


# For every selected run (column k of fig_all / fig_res): collect samples into
# distance slices and estimate the approach angle for each completed slice.
for k, (i_row, row) in enumerate(df_total[df_total.appendix.isin(appendices)].iterrows()):

    slice_idx = 0  # row of fig_all / fig_res used by the next completed slice
    # per-run figure of the corrected trajectory, coloured by slice
    fig_pos, ax_pos = plt.subplots()
    fig_pos.set_size_inches(3, 3)
    ax_pos.set_xlabel('x [cm]')
    ax_pos.set_ylabel('y [cm]')
    ax_pos.axis('equal')

    # NOTE(review): azimuth_deg and gt_angle are assigned but not used in this loop.
    azimuth_deg = starting_positions[row.appendix][3]
    gt_angle = get_gt_angle(row)

    start_i = 0  # first sample index of the current slice

    # NOTE(review): created once per run, not per slice (section 1 creates new
    # estimators per slice); depending on how AngleEstimator.add_distribution
    # stores entries, later slices may include earlier slices' data. Confirm.
    angle_estimators = {
        'estimated': AngleEstimator(),
    }
    data_collector = DataCollector(exp_name=exp_name)

    # the first skip_first samples of each run are ignored
    for i in range(skip_first, row.stft.shape[0]):
        signals_f = row.stft[i]
        frequencies = row.frequencies_matrix[i]
        position = row.positions[i]

        if i == row.stft.shape[0] - 1:
            d_slice_ready = True
            print('reached end')
        else:
            # NOTE(review): position*1e2 scales every column, including yaw (column 3).
            d_slice_ready = data_collector.next_dslice_ready(
                signals_f, frequencies, position*1e2, n_max=50)

        if d_slice_ready:
            d_slices, distances, stds, freqs = data_collector.get_current_distance_slice()
            # too few distances: stop processing this run
            if len(distances) < 10:
                print(f'skipping last measurement, cause only {len(distances)}')
                break

            # single frequency used for the angle inference of this slice
            freq = np.mean(freqs)

            print(f'positions from {start_i} to {i}')
            # NOTE(review): all logged samples in [start_i, i), including those the
            # collector rejected as not valid (see the end of this loop).
            positions = row.positions[start_i:i, :]

            positions_corr = get_corrected_positions(row.appendix, positions)
            ds = -positions_corr[:, 1]
            #ax_df.scatter(ds, np.full(len(ds), freq), color=f'C{slice_idx}')

            mean_angle = get_average_angle(positions_corr)
            mean_angles.append(mean_angle)

            ax_pos.scatter(positions_corr[:, 0],
                           positions_corr[:, 1],
                           color=f'C{slice_idx}',
                           s=10)
            ax_pos.set_title(f'experiment {k}')

            # trajectory of this slice shifted to start at the origin
            # NOTE(review): IndexError if every sample of the slice was filtered out.
            positions_here = positions_corr - positions_corr[0, :]

            # NOTE(review): slice_idx is not bounded by n_rows; a run with more
            # than n_rows slices raises IndexError here.
            axs_res[slice_idx, k].scatter(positions_here[:, 0],
                positions_here[:, 1],
                color=f'C{slice_idx}',
                s=10,
                label="trajectory"
            )
            axs_res[0, k].set_title(f'exp. {k}')
            axs_all[0, k].set_title(f'exp. {k}')

            if plot_raw:
                fig, axs = plt.subplots(2, n_mics, sharey='row')
                fig.set_size_inches(10, 7)

            # fold angles of 90 deg or more onto 180 - angle for plotting
            plot_mean_angle = mean_angle if mean_angle < 90 else 180 - mean_angle
            axs_all[slice_idx, k].axvline(plot_mean_angle, color='k', ls=':', label='mean angle')

            # one angle distribution per selected microphone
            for i_mic, mic_idx in enumerate(mics):
                d_slice = d_slices[mic_idx]
                if plot_raw:
                    axs[0, i_mic].set_title(f'mic{mic_idx}')
                    axs[0, i_mic].scatter(distances, d_slice, color='C0')
                    axs[0, i_mic].set_xlabel(f"relative distance [cm]")

                valid = ~np.isnan(d_slice)  # NOTE(review): unused
                angles, proba = get_approach_angle_fft(
                    d_slice=d_slice,
                    frequency=freq,
                    relative_distances_cm=distances,
                    reduced=True,
                    bayes=True
                )
                angle_estimators['estimated'].add_distribution(angles, proba, mic_idx, freq)

                if plot_raw:
                    axs[1, i_mic].plot(angles, proba, color='C1')
                    axs[1, i_mic].axvline(mean_angle, color='k', ls=':')
            if plot_raw:
                axs[0, 0].set_ylabel(f"magnitude{row.appendix}")
                axs[1, 0].set_ylabel(f"probability{row.appendix}")

            l=20  # length [cm] of the direction rays drawn in fig_res
            axs_res[slice_idx, k].plot(
                [0, l*np.cos(mean_angle/180*np.pi)],
                [0, l*np.sin(mean_angle/180*np.pi)],
                color='k',
                ls=':',
                label='mean angle'
            )
            for key, angle_estimator in angle_estimators.items():
                angles, probs = angle_estimator.get_angle_distribution(mics_left_right=[[1], [3]])

                est_angle = get_estimate(angles, probs)

                # fold angles above 90 deg onto 180 - angle, as for the mean angle
                angles_plot = deepcopy(angles)
                angles_plot[angles_plot > 90] = 180 - angles[angles_plot > 90]
                axs_all[slice_idx, k].plot(angles_plot, probs,
                                        color=f'C{slice_idx}',
                                        label=keys[key],
                                        ls=linestyles[key])
                #axs_all[slice_idx].set_yscale('log')

                # NOTE(review): this ray uses -cos whereas the mean-angle ray above
                # uses +cos (x mirrored); confirm this is intended.
                axs_res[slice_idx, k].plot(
                    [0, -l*np.cos(est_angle/180*np.pi)],
                    [0, l*np.sin(est_angle/180*np.pi)],
                    color=f'C{slice_idx}',
                    ls=linestyles[key],
                    label=keys[key]
                )
                axs_res[slice_idx, k].axis('equal')

            slice_idx += 1
            #if plot_raw:
                #save_fig(fig, f'plots/experiments/{exp_name}{row.appendix}_slice{slice_idx}.png', extension='png')
            start_i = i

        # only add measurements if the drone is really flying.
        if data_collector.valid_dslice_measurement(position*1e2, signals_f, frequencies):
            data_collector.fill_from_signal(
                signals_f, frequencies, distance_cm=position[1]*1e2, time=row.seconds[i]
            )
    ax_pos.axhline(0, color='k')
    save_fig(fig_pos, f'plots/experiments/{exp_name}_pos{i_row}.png')
            
# fig_df was created in the positions cell of this section (plot_df(distance_range=[7, 70])).
for fig_to_save, out_path in (
    (fig_all, f'plots/experiments/{exp_name}_all.png'),
    (fig_res, f'plots/experiments/{exp_name}_res.png'),
    (fig_df, f'plots/experiments/{exp_name}_df.png'),
):
    save_fig(fig_to_save, out_path)