In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import copy
import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import umap
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from torch.utils.data import Dataset, DataLoader

import do_not_circ as dnc
from src.nets import Nets
from src.utils import *
from src.postprocessing.postprocessing import *
from src.postprocessing.stats_plotting import *
from src.postprocessing.interpolation import *

from src.save_load import *


In [None]:
# Marker colours for the reference points drawn on top of the UMAP scatter:
# the centre model first, then the four basis-vector models in order.
COLOR_CENTER = "gold"
COLOR1 = "thistle" 
COLOR2 = "lightsalmon" 
COLOR3 = "skyblue" 
COLOR4 = "rosybrown"

In [None]:
# Root of the project data; read from the environment (raises KeyError if the
# variable PATH_TO_DNC_FOLDER is not set).
root_folder = os.environ["PATH_TO_DNC_FOLDER"]
# Sub-folder name under "<root>/experiments"; left empty here.
# NOTE(review): presumably meant to be filled in with the experiment name before running.
exp = ""
experiment_folder = os.path.join(root_folder, "experiments", exp)



In [None]:
# Output folder for the figures produced by this notebook: ./images/<exp>.
image_folder = os.path.join(".", "images", exp)
# makedirs also creates the missing "./images" parent (os.mkdir would raise
# FileNotFoundError) and exist_ok=True makes re-running the cell a no-op.
os.makedirs(image_folder, exist_ok=True)

In [None]:
# Load the configurations stored in the experiment folder.
cfgs = load_configs(experiment_folder)
# Show the index; its entries are what `exp_id` is looked up against (cfgs.loc[exp_id]).
cfgs.index

In [None]:
# Index label of the run to analyse; used for cfgs.loc[...] and to select the
# run's models. Left empty here.
# NOTE(review): presumably to be filled in with one of the labels shown by cfgs.index.
exp_id = ""


In [None]:
# Metadata collected by the following cells and stored next to the cached loss
# grid (passed as `meta_dict` to cache_data).
cache_dict = {}

In [None]:
# Load data
device = None

# Configuration row of the selected run.
cfg = cfgs.loc[exp_id]
# Override the dataset size entry for this analysis.
# NOTE(review): this writes into the nested "data_meta" object, which may be
# shared with `cfgs` - confirm that mutating it is intended.
cfg["data_meta"]["N"] = 500

data_set = get_data(cfg, device=device)
# Fixed order (no shuffling); batch size is taken from the same config row.
data_loader = DataLoader(
    dataset=data_set,
    batch_size=cfg["batch_size"],
    shuffle=False,
)

# Loss used for every model evaluation in this notebook.
criterion = torch.nn.MSELoss()


In [None]:

# Centre model: the point around which the offset basis and the grid are built.
center_step = 1001
center_idx = 0

# Models are looked up by run id first, then by the model index as a string.
center_model = get_all_models(
    experiment_folder, center_step
)[exp_id][str(center_idx)]

# Remember which model served as the centre.
cache_dict["center_config"] = dict(step=center_step, idx=center_idx)


In [None]:
# Basis vectors from trained models: one (step, model index) pair per direction.
# NOTE: the next cell overwrites `basis_vectors` with random directions; run
# only one of the two cells.
model_steps = [1001, 1001, 1001, 1001]
model_idxs = [1, 2, 3, 4]
assert len(model_steps) == len(model_idxs), "need one model index per step"

# Load every distinct step only once instead of once per basis vector
# (the same step would otherwise be read from disk repeatedly).
models_by_step = {
    step: get_all_models(experiment_folder, step)[exp_id]
    for step in set(model_steps)
}
basis_vectors = [
    get_params_vec(models_by_step[step][str(idx)])
    for step, idx in zip(model_steps, model_idxs)
]
# Drop the references to the loaded models; only the parameter vectors are needed.
del models_by_step

# Record where the basis vectors came from.
cache_dict["basis_vectors"] = "trained"
cache_dict["basis_vectors_config"] = {
    "steps": model_steps,
    "model_idxs": model_idxs
}

In [None]:
# Basis vectors from random directions: `num_dir` standard-normal vectors with
# one entry per parameter of the centre model.
num_parameters = get_model_num_params(center_model)
num_dir = 4

# Draw from a dedicated, seeded generator so the directions (and therefore the
# cached loss grid) are reproducible without touching the global torch RNG.
basis_seed = 0
basis_rng = torch.Generator().manual_seed(basis_seed)
basis_vectors = torch.randn(num_dir, num_parameters, generator=basis_rng)

# Record where the basis vectors came from, including the seed.
cache_dict["basis_vectors"] = "random"
cache_dict["basis_vectors_config"] = {"seed": basis_seed}

In [None]:
# Build the basis relative to the centre model from the chosen basis vectors,
# then convert every vector to a torch tensor.
orthonorm_raw = create_offset_orthonorm_basis_new(center_model, basis_vectors)
basis_orthonorm_vectors = list(map(torch.Tensor, orthonorm_raw))

# Stored with the cache so the grid can be rebuilt after loading.
cache_dict["basis_orthonorm_vectors"] = basis_orthonorm_vectors

In [None]:
# Coordinates of the first three basis vectors with respect to the basis
# centred at the centre model.
c1, c2, c3 = (
    get_coordinates(basis_vectors[k], basis_orthonorm_vectors, get_params_vec(center_model))
    for k in range(3)
)

# Coordinates assigned to the centre model itself.
cO = [0, 0, 0]

# Full-dataset loss of the network rebuilt from each of the four basis vectors.
l1, l2, l3, l4 = (
    get_net_loss(vec_to_net(basis_vectors[k], center_model), data_loader, criterion, full_dataset=True)
    for k in range(4)
)

# Full-dataset loss of the centre model.
lO = get_net_loss(center_model, data_loader, criterion, full_dataset=True)

In [None]:
# Build the grid of parameter vectors around the centre model and evaluate the
# loss at every grid point.
num_inter_models = 20
grid_bound = [-13, 13]
cache_dict.update(num_inter_models=num_inter_models, grid_bound=grid_bound)


def func(m):
    """Full-dataset loss of model `m` under `criterion`."""
    return get_net_loss(m, data_loader, criterion, full_dataset=True, device=None)


grid = get_models_grid(
    center_model, basis_orthonorm_vectors, num_inter_models, grid_bound
)
vals = get_model_interpolate_grid(
    center_model, basis_orthonorm_vectors, num_inter_models, grid_bound, func
)

# Persist the loss values together with the metadata gathered so far.
cache_data(experiment_folder, "UMAP_HighD", vals, meta_dict=cache_dict, time_stamp=True)

In [None]:

# Load cache: restore the loss values and the settings they were computed with.
# NOTE(review): `time_stamp` is empty here - presumably to be set to the stamp
# of the cached run that should be loaded.
vals, meta_data = load_cached_data(experiment_folder, "UMAP_HighD", time_stamp="")

basis_orthonorm_vectors = meta_data["basis_orthonorm_vectors"]
num_inter_models = meta_data["num_inter_models"]
grid_bound = meta_data["grid_bound"]

# Evenly spaced values between the two grid bounds.
grid_arr = np.linspace(start=grid_bound[0], stop=grid_bound[1], num=num_inter_models)

# Rebuild the grid of parameter vectors from the restored settings.
grid = get_models_grid(
    center_model, basis_orthonorm_vectors, num_inter_models, grid_bound
)


In [None]:
# Filter the grid by loss: boolean mask over the flattened loss values that is
# True where the loss is strictly below the cut-off.
loss_cutoff = 1e-5
grid_filter = vals.reshape(-1) < loss_cutoff


In [None]:
# Check how much remains: flatten the grid to one row per grid point (keeping
# the last axis) and show the shape of the rows that pass the loss filter.
grid.reshape(-1, grid.shape[-1])[grid_filter].shape

In [None]:
# Fit UMAP on the low-loss grid points plus the reference models.
# A fixed random_state makes the embedding (and the saved figure) reproducible
# across runs; note that UMAP may run single-threaded when a seed is set.
umap_seed = 0
fit = umap.UMAP(
    n_neighbors=200,
    min_dist=0.4,
    metric='euclidean',
    verbose=True,
    random_state=umap_seed,
)

# One row per grid point that passed the loss filter ...
filtered_grid = grid.reshape(np.prod(grid.shape[:-1]), -1)[grid_filter]
# ... followed by the centre model and the basis vectors as the LAST rows, so
# the plotting code can pick them out from the end of the embedding.
filtered_grid = np.concatenate([
    filtered_grid,
    [get_params_vec(center_model).detach().numpy()],
    [b.detach().numpy() for b in basis_vectors],
])

# 2-D embedding, one row per row of `filtered_grid`.
u = fit.fit_transform(filtered_grid)


In [None]:
# Merge fit with labels
# The last rows of `u` are the reference points (centre model + basis vectors),
# one per marker colour; everything before them is a grid point.
marker_colors = [COLOR_CENTER, COLOR1, COLOR2, COLOR3, COLOR4]
n_markers = len(marker_colors)

labels = np.concatenate([vals.reshape(-1)[grid_filter], [lO, l1, l2, l3, l4]])

dataset = pd.DataFrame({'x1': u[:-n_markers, 0],
                        'x2': u[:-n_markers, 1],
                        'label': labels[:-n_markers]})

# Single colour normalisation shared by the scatter points and the colourbar.
# Without passing it as hue_norm, seaborn scales the hue to the data's own
# min/max and the colourbar would not match the point colours.
norm = plt.Normalize(0, 1e-5)

# Plot
sns.set(font_scale=1.4, rc={'figure.figsize':(13, 10)})
sns.set_style("white", {'axes.spines.left': False,
                         'axes.spines.bottom': False,
                         'axes.spines.right': False,
                         'axes.spines.top': False})

# Grid points, coloured by loss.
ax = sns.scatterplot(x="x1", y="x2", hue="label",
                     hue_norm=norm,
                     data=dataset,
                     s=10,
                     palette='inferno',
                     edgecolor="none")

# Reference points as large markers on the same axes, one colour each.
ax = sns.scatterplot(x=u[-n_markers:, 0], y=u[-n_markers:, 1],
                     s=300, hue=marker_colors, palette=marker_colors,
                     ax=ax)

ax.set(yticks=[], xticks=[], xlabel='', ylabel='')
ax.get_legend().remove()

# Colourbar for the loss, drawn in an inset axes anchored to the right edge of
# the plot.
sm = plt.cm.ScalarMappable(cmap="inferno", norm=norm)
cax = inset_axes(ax,
                   width="4%",  # width = 4% of the bounding box width
                   height="75%",  # height = 75% of the bounding box height
                   loc='center left',
                   bbox_to_anchor=(1, 0., 1, 1),
                   bbox_transform=ax.transAxes,
                   borderpad=0,
                   )
cbar = ax.figure.colorbar(sm, cax=cax, )
cbar.outline.set_visible(False)
cbar.ax.tick_params(size=0)
cbar.ax.get_yaxis().labelpad = 5
cbar.ax.tick_params(labelsize=20)
cbar.ax.set_ylabel('J(' + '\u03B8' + ')', rotation=0, size=20)

ax.get_figure().savefig(os.path.join(image_folder, "UMAP_4d.pdf"), bbox_inches = 'tight', pad_inches = 0)



In [None]:
# Largest loss value among the plotted grid points (the 'label' column of the plot data).
dataset['label'].max()