In [1]:
import numpy as np
import torch
from pathlib import Path
import os
import sys
sys.path.append('../')

# from orbit_datasets import neworbits, versatileorbits, staticorbits
from ldcl.data import neworbits, versatileorbits, staticorbits

In [2]:
import matplotlib.pyplot as plt
plt.style.use('ggplot')

# https://plotly.com/python/creating-and-updating-figures/
import plotly.graph_objects as go

from mpl_toolkits import mplot3d

In [3]:
from ldcl.data import physics

In [4]:
# Obtain the training dataset through physics.get_dataset. The two string
# arguments look like a config file and a dataset storage directory
# (NOTE(review): confirm against ldcl.data.physics). The second return value
# is printed so the reader can see which folder was used.
train_orbits_dataset, folder = physics.get_dataset(
    "../variance_orbit_config.json",
    "../../saved_datasets",
)
print(folder)

..\..\saved_datasets\20221213-4\


In [5]:
# One sample per batch, visited in random order. applyNetworks keeps only the
# first element of every batch, so batch_size must stay at 1 for every sample
# to be collected.
orbits_loader = torch.utils.data.DataLoader(
    train_orbits_dataset,
    batch_size=1,
    shuffle=True,
)

In [8]:
def applyNetworks(locations, loader=None):
    """Push every sample of an orbit loader through a chain of networks.

    The networks are applied in the order given: the first one receives
    ``input1.float()`` and each later one receives the previous output.

    Args:
        locations: Sequence whose items are either something ``torch.load``
            can open (the loaded object must be callable and provide
            ``.eval()``) or an already-instantiated ``torch.nn.Module``,
            which is used as-is. Every network is switched to eval mode.
        loader: Iterable yielding ``(input1, input2, y)`` triples, where
            ``y`` maps the keys ``'phi0'``, ``'H'`` and ``'L'`` to tensors.
            ``input2`` is ignored. When omitted, the notebook-level
            ``orbits_loader`` is used.

    Returns:
        A pair ``(encoder_outputs, [phi0, H, L])``. ``encoder_outputs`` is
        the ``np.vstack`` of the first output row of every batch; the three
        arrays hold, per batch, the first element of ``y['phi0']``,
        ``y['H']`` and ``y['L']`` respectively.

    Raises:
        ValueError: If the loader yields no batches.

    Note:
        Only the first sample of each batch is kept, so the loader is
        expected to use ``batch_size=1``; larger batches silently drop data.
    """
    if loader is None:
        # Keep the original behaviour of reading the notebook-global loader.
        loader = orbits_loader

    # SECURITY: torch.load unpickles arbitrary Python objects; only pass
    # checkpoint paths that come from a trusted source.
    net_list = [
        x if isinstance(x, torch.nn.Module)
        else torch.load(x, map_location=torch.device('cpu'))
        for x in locations
    ]
    for net in net_list:
        net.eval()

    encoder_outputs_list = []
    target_values = []

    # Pure inference: disabling autograd avoids building a graph per sample.
    with torch.no_grad():
        for input1, _input2, y in loader:
            out = input1.float()
            for f in net_list:
                out = f(out)
            encoder_outputs_list.append(out.detach().numpy()[0])

            # Conserved quantities for this sample, in the column order
            # [phi0, H, L] that the return value relies on.
            target_values.append(np.array([
                y[key].detach().numpy().flatten()[0]
                for key in ('phi0', 'H', 'L')
            ]))

    if not encoder_outputs_list:
        raise ValueError("applyNetworks: the loader produced no batches")

    encoder_outputs = np.vstack(encoder_outputs_list)
    target_values = np.vstack(target_values)

    phi0_c_values = target_values[:, 0]
    energy_c_values = target_values[:, 1]
    angular_momentum_c_values = target_values[:, 2]

    return encoder_outputs, [phi0_c_values, energy_c_values, angular_momentum_c_values]

In [12]:
# Checkpoint of the encoder to evaluate.
ENCODER_PATH = Path("../saved_models/simclr_test1/final_encoder.pt")
print(ENCODER_PATH)

# Run the single-network chain over the loader.
encoder_outputs, conserved_quantities = applyNetworks([ENCODER_PATH])

# One label per sample, combining the three conserved quantities returned by
# applyNetworks in the order (phi0, H, L).
annotations = [
    f'phi_0: {p:.3f}, H: {h:.3f}, L: {l:.3f}'
    for (p, h, l) in zip(*conserved_quantities)
]

# Sanity check: number of rows and number of labels should agree.
print(encoder_outputs.shape)
print(len(annotations))

..\saved_models\simclr_test1\final_encoder.pt
(1000, 3)
1000
