In [None]:
%load_ext autoreload
%autoreload 2

from imports import *
from models import *
from utils import *
from data import *
from configs import CONFIGS, EXP_CODES

In [None]:
'''
load saved experiment
'''
# Name of the experiment whose config (and later its trained weights) is used.
config_name = 'freq_3dim'
# Work on a private copy of the registry entry (presumably so later edits to
# `config` cannot leak back into CONFIGS).
config = deepcopy(CONFIGS[config_name])

(
    train_dataset,
    test_dataset,
    train_loader,
    test_loader,
) = get_dataset_dataloader(config)

In [None]:
'''
load model from wandb run
'''
model_code = EXP_CODES[config_name][0]

fix_randomness(seed=42)

model = config['model_class'](seed=0, **config['model_args']).cuda()
# Load the checkpoint onto the CPU first so it works regardless of which device
# it was saved from; load_state_dict copies the tensors into the CUDA model.
state_dict = torch.load('./models/model_{}.pt'.format(model_code), map_location='cpu')
model.load_state_dict(state_dict)
_ = model.eval()


'''
generate random trajectory from test dataset
'''
trajectory, gt_vs = test_dataset.generate_sample_trajectory(length=2000)

# Pair every step t with step t+1 along the first axis and add a batch dim of 1.
i1 = trajectory[:-1, :].unsqueeze(0).cuda()
i2 = trajectory[1:, :].unsqueeze(0).cuda()

with torch.no_grad():
    pred_vs = model_encoder(model=model, first_img=i1, second_img=i2)
    pred_vs = pred_vs.squeeze().cpu()


def _zero_pad_columns(arr, width):
    """Return the 2-D array `arr` with zero columns appended up to `width` columns."""
    return np.hstack((arr, np.zeros((arr.shape[0], width - arr.shape[1]))))


# Squeeze once, up front, so the ndim checks and the shape/reshape calls below
# all describe the same array (previously ndim was checked on the squeezed
# array while shape[0]/shape[1] were read from the unsqueezed one).
gt_velocities = np.array(gt_vs).squeeze()
pred_velocities = pred_vs.numpy()
colors = test_dataset.cmap(gt_vs)

if gt_velocities.ndim == 2 and pred_velocities.ndim == 2:
    # 2-column ground truth vs 3-column prediction: add a zero third column to
    # the ground truth so both arrays have the same width.
    if gt_velocities.shape[1] == 2 and pred_velocities.shape[1] == 3:
        gt_velocities = _zero_pad_columns(gt_velocities, 3)

if gt_velocities.ndim == 1 and pred_velocities.ndim == 1:
    # 1-D velocities: add a zero second column to both so they can be
    # scattered in 2-D.
    gt_velocities = _zero_pad_columns(gt_velocities.reshape(-1, 1), 2)
    pred_velocities = _zero_pad_columns(pred_velocities.reshape(-1, 1), 2)


plotly_scatter(pred_velocities, colors, '')

In [None]:
# PCA subtracts the per-feature mean itself, so no manual centering is needed.
pca = PCA()

pca.fit(pred_velocities)

print(pca.explained_variance_ratio_)

# Fraction of variance captured by all components except the weakest one.
# explained_variance_ratio_ is sorted in decreasing order, so dropping the last
# entry removes exactly one component. Masking with `!= min(...)` would instead
# drop every component tied for the minimum (e.g. several exactly-zero ones).
total_variance = pca.explained_variance_ratio_[:-1].sum()

In [None]:
'''
remove outliers
'''

# compute_error_metric yields a scalar error and a transformed copy of the
# predictions. num_clusters / show_knee_visualization are options of that
# helper (see its definition for their exact meaning).
error, transformed_pred_velocities = compute_error_metric(
    pred=pred_velocities,
    true=gt_velocities,
    show_knee_visualization=True,
    num_clusters=30,
)

In [None]:
# Interactive scatter of the ground-truth velocities, using the dataset colormap.
plotly_scatter(
    gt_velocities,
    colors,
    'true velocities',
)

In [None]:
# Interactive scatter of the raw model predictions, colored like the ground truth.
plotly_scatter(pred_velocities, colors, 'predicted velocities')

In [None]:
# Interactive scatter of the predictions after compute_error_metric's
# transformation, with the resulting error in the title.
plotly_scatter(transformed_pred_velocities, colors,
               'predicted velocities transformed, error: {:.4f}'.format(error))

In [None]:
plt.rcParams.update({'font.size': 22})

fig, ax = plt.subplots()

# Plot the first two columns explicitly. gt_velocities may have been
# zero-padded to three columns to match the prediction. Unpacking
# `*gt_velocities.T` would then feed the third column into scatter's positional
# `s` parameter, which clashes with `s=100` and raises a TypeError.
ax.scatter(gt_velocities[:, 0], gt_velocities[:, 1], c=colors, s=100)

ax.set_xlim([-0.06, 0.06])
ax.set_ylim([-0.06, 0.06])
ax.set_xlabel('dim 1')
ax.set_ylabel('dim 2')
ax.set_title('1d Frequency Modulation true space')

# Hide ticks entirely (with no ticks there are no tick labels to disable).
ax.set_xticks([])
ax.set_yticks([])
plt.show()

In [None]:
# plt.rcParams.update({'font.size': 22})

# plt.figure()

# plt.scatter(*pred_velocities.T, c=colors)

# ax = plt.gca()
# ax.set_xlim([-0.09, 0.09])
# ax.set_ylim([-0.09, 0.09])
# ax.set_xlabel('dim 1')
# ax.set_ylabel('dim 2')
# ax.set_title('Our Model Error: {:.4f}'.format(error))

# ax.xaxis.set_tick_params(labelbottom=False)
# ax.yaxis.set_tick_params(labelleft=False)
# ax.set_xticks([])
# ax.set_yticks([])

In [None]:
plt.rcParams.update({'font.size': 22})

fig, ax = plt.subplots()

# Plot the first two prediction dimensions explicitly. For 3-column predictions,
# `*pred_velocities.T` would pass the third coordinate as scatter's positional
# `s` (marker size) argument, which silently distorts or hides the markers.
ax.scatter(pred_velocities[:, 0], pred_velocities[:, 1], c=colors)

ax.set_xlim([-0.005, 0.001])
ax.set_ylim([-0.005, 0.001])
ax.set_xlabel('dim 1')
ax.set_ylabel('dim 2')
# total_variance: PCA variance share of all but the weakest component.
ax.set_title('1 dimension explained \nwith {:.2f} variance'.format(total_variance))

# Hide ticks entirely (with no ticks there are no tick labels to disable).
ax.set_xticks([])
ax.set_yticks([])
plt.show()

In [None]:
plt.rcParams.update({'font.size': 22})

# Single 3-D axes on a fresh figure.
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})

ax.scatter(*pred_velocities.T, c=colors)

# NOTE(review): the x upper limit (0.00005) differs from the 2-D plot's 0.001;
# confirm it is intentional.
ax.set_xlim([-0.005, 0.00005])
ax.set_ylim([-0.005, 0.005])
ax.set_zlim([-0.005, 0.005])
ax.set(xlabel='dim 1', ylabel='dim 2', zlabel='dim 3')
ax.set_title('1 dimension explained \nwith {:.2f} variance'.format(total_variance))

# Strip tick labels and ticks on every axis.
ax.xaxis.set_tick_params(labelbottom=False)
ax.yaxis.set_tick_params(labelleft=False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
# Pull the axis labels in toward the cube now that the ticks are gone.
for _axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    _axis.labelpad = -11.5

# ax.view_init(elev=10., azim=250.0)

In [None]:
# ransac = RANSACRegressor()
# ransac.fit(pred_velocities[:, :2], pred_velocities[:, 2])

# # Find inliers and outliers
# inlier_mask = ransac.inlier_mask_
# outlier_mask = np.logical_not(inlier_mask)
# percent_outlier = outlier_mask.sum() / pred_velocities.shape[0]

# plotly_scatter(pred_velocities[inlier_mask], colors=colors[inlier_mask], title='')

# residual = np.linalg.norm(pred_velocities[:, 2] - ransac.predict(pred_velocities[:, :2]))

# print("percent of outliers {:.3f}, residual: {:.3f}".format(percent_outlier, residual))

# plotly_scatter(pred_velocities[inlier_mask], colors=colors[inlier_mask], title='')