In [None]:
import numpy as np
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt

from pathlib import Path

import torch
from torch import nn


from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm


# Input/output locations for the 2D experiment (relative to the working directory).
BASE_PATH = Path("data/new_functions/2d")
DATA_PATH = BASE_PATH / "point_cloud_samples"  # sampled anchor-point CSV files
MODEL_PATH = BASE_PATH / "models" / "simple_tanh"  # trained network checkpoints (.pt)


# Embed fonts as TrueType (Type 42) instead of Type 3 in PDF and PostScript
# exports, so text in saved figures stays selectable/editable.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

In [None]:
DIM = 2  # number of input features fed to each network

# Evaluation grid: 300 evenly spaced points on [-5, 5] along each input axis.
x1 = np.linspace(start=-5, stop=5, num=300)
x2 = np.linspace(start=-5, stop=5, num=300)

# X1/X2 are the 300x300 coordinate grids; X stacks them into one (x1, x2)
# row per grid point, in row-major order, so the model can score all at once.
X1, X2 = np.meshgrid(x1, x2)
X = np.column_stack((X1.ravel(), X2.ravel()))

In [None]:
# One 3D panel per sampled function: the trained network's prediction surface
# over the grid (X1, X2) plus the sampled anchor points (x0, x1, y) from its CSV.
N_FUNCTIONS = 5      # panels, one per vertice_sample0 .. vertice_sample4
N_REPS = 1           # sample repetitions per function; every rep is drawn on the same panel
NN_REP = 0           # which trained-network repetition to load
N_EPOCHS = 40000     # training length encoded in the checkpoint filename
SAVE_FIGURE = False  # set True to write the PNG to the working directory

# The grid is identical for every model, so convert it to float32 once.
X_tensor = torch.as_tensor(X, dtype=torch.float32)


def build_model(dim: int = DIM) -> nn.Module:
    """Return a fresh, untrained copy of the architecture stored in the checkpoints.

    The layer layout must match the saved state dicts exactly, because they
    are loaded with ``strict=True``.
    """
    return nn.Sequential(
        nn.Linear(dim, 512),
        nn.Tanh(),
        nn.Linear(512, 1),
        nn.Sigmoid(),
    )


fig, ax = plt.subplots(
    ncols=N_FUNCTIONS,
    nrows=1,
    subplot_kw={"projection": "3d"},
    figsize=[7.00697, 1.401394],
)
axes = ax.ravel()
for i in range(N_FUNCTIONS):
    for rep in range(N_REPS):
        # Shared stem of the CSV sample file and its checkpoint file.
        sample_name = f"vertice_sample{i}_d2_rep{rep}_cma_sample"

        anchor_points = pd.read_csv(DATA_PATH / f"{sample_name}.csv")
        anchor_x0 = anchor_points["x0"]
        anchor_x1 = anchor_points["x1"]
        anchor_x2 = anchor_points["y"]  # column "y", drawn on the z axis

        model = build_model()
        # weights_only=True limits unpickling to tensors/containers, which is all
        # a state dict needs (parameter available since torch 1.13).
        state_dict = torch.load(
            MODEL_PATH / f"{sample_name}_NNrep_{NN_REP}_NEPOCHS{N_EPOCHS}.pt",
            map_location=torch.device("cpu"),
            weights_only=True,
        )
        model.load_state_dict(state_dict, strict=True)
        model.eval()

        # Inference only: skip building an autograd graph over the whole grid.
        # Calling the module (not .forward) keeps any registered hooks working.
        with torch.no_grad():
            y = model(X_tensor).numpy().reshape(X1.shape)

        panel = axes[i]
        panel.scatter(anchor_x0, anchor_x1, anchor_x2, color="black", s=5)
        panel.plot_surface(X1, X2, y, cmap=cm.viridis)
        panel.set_xticks([])
        panel.set_yticks([])
        panel.set_zticks([])
        panel.view_init(elev=25)
        panel.set_title(str(i + 1), y=-0.09, fontsize=10)

fig.tight_layout()
if SAVE_FIGURE:
    fig.savefig("FirsRepsOfNewFuncs3dPlots.png", dpi=300)