# Motivation

This notebook will be used to document and test my idea of using structure vectors as learned descriptors that are then passed through a shallow neural network. This idea was motivated by two main facts:

1) I believe that my current models are being limited by the fact that they lack the "mixing" abilities of NN-style potentials. In NNPs, each descriptor component is a sum over the local environment of an atom; those components are then mixed with each other by the NN. In my models this mixing is impossible because each model is essentially a single-layer network.

2) None of the major MLIPs have learned descriptors, at least as far as I can tell. Based on my understanding, NNP (and its variants), GAP, MTP, and SNAP all define their descriptors using expansions with various basis functions, with the difference between each model being which basis functions they use and which distance kernels they use. SchNet would be an example of a model with learned descriptors (since they use CNN-like filters), but reportedly suffers from significant speed issues.

# Open questions
* How are eng and force SVs used at the same time?

# TO DO
* Normalize data
* Try fitting SVs first
* Try using CMA-ES
* Try smaller SVs, but more of them
* **Pin ends of radial functions**


# Theory notes

## Why are my SVTrees just single-layer networks?
Though this comment had been made by a reviewer on my s-MEAM work, I hadn't fully understood it until now. I came to this realization when I spent more time understanding how the descriptors were built for NNPs ([the ANI paper](https://pubs.rsc.org/en/content/articlepdf/2017/sc/c6sc05720a) and [its supplementary information](http://www.rsc.org/suppdata/c6/sc/c6sc05720a/c6sc05720a1.pdf) were particularly useful). The individual components of the NNP descriptors are constructed by choosing a basis function (e.g. a shifted gaussian) and summing its values over pair distances or angles within a single atomic environment. The full descriptors are then the concatenation of several of these components for each "bond type". These descriptors are then fed through a neural network (using a different network for each "host atom" type, which is exactly what I do). In the limit where the NN has only a single layer, the components are simply summed, which is exactly the case with my trees. There are some minor technical differences, like the fact that my trees can essentially have different activation functions for each node (by wrapping individual SVs in different embedding functions) and that my SVs are learnable and will most likely be more complex than any single basis function (e.g. it could take up to $k$ RBFs to approximate a spline with $k$ knots), but the statement is still essentially true: **my trees are single-layer NN potentials.**
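
To make the single-layer argument concrete, here is a minimal NumPy sketch. The Gaussian width `ETA`, the shifts `CENTERS`, and the neighbor distances are made-up values for illustration, not the ANI settings. It builds descriptor components by summing shifted gaussians over one atomic environment, then checks that a single linear layer on those components gives the same energy as summing one fixed pair function over the neighbors.

```python
import numpy as np

ETA = 4.0                             # assumed Gaussian width parameter
CENTERS = np.linspace(2.5, 7.0, 7)    # assumed shifts of the 7 Gaussians

def gaussian_descriptor(distances):
    """Descriptor components: each Gaussian summed over all neighbor distances."""
    diffs = distances[:, None] - CENTERS[None, :]   # (n_neighbors, n_centers)
    return np.exp(-ETA * diffs**2).sum(axis=0)      # (n_centers,)

def pair_function(r, w):
    """Radial function implied by a single linear layer with weights w."""
    return np.sum(w * np.exp(-ETA * (r - CENTERS)**2))

# Hypothetical local environment (neighbor distances in Angstrom)
distances = np.array([2.6, 2.9, 3.4, 4.1])

rng = np.random.default_rng(0)
w, b = rng.normal(size=CENTERS.size), 0.1

# Single linear layer acting on the descriptor
e_single_layer = w @ gaussian_descriptor(distances) + b

# The same energy written as a plain sum of a pair function over neighbors
e_pair_sum = sum(pair_function(r, w) for r in distances) + b

print(np.isclose(e_single_layer, e_pair_sum))
```

With only one layer there is no mixing between components: the weights just define a single effective radial function, which is the role an SV already plays in my trees.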

## What is the idea that I'm proposing?
In order to improve the accuracy/flexibility of my models, my proposition is to use the outputs of multiple SVs as the inputs to a shallow neural network. For example, I could define an atomic fingerprint for atom $i$ of element type $X$ as the length-2 vector $$\vec{G}^X_i(S) = \langle \rho(S)_i \text{, } \text{FFG}(S)_i \rangle $$ where $\rho(S)_i$ and $\text{FFG}(S)_i$ are the 2- and 3-body structure vectors of the local environment around atom $i$. The network would then have an input dimension of 2 (since 2 SVs were used), and an output dimension of 1 (for the atomic energy). This would allow the type of "component mixing" that I believe is responsible for the high accuracy of NNP-like models. The number of 2- and 3-body SVs to use would be a hyper-parameter choice, but based on past experience I think 2-3 of each would prove to be sufficient.
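
As a rough illustration of the proposal, the sketch below passes a length-2 fingerprint per atom through one hidden layer and sums the atomic energies. Everything here is an assumption for illustration: the fingerprint values are random stand-ins for $\rho(S)_i$ and $\text{FFG}(S)_i$, the hidden width of 3 is arbitrary, and the weights are untrained.

```python
import numpy as np

def celu(x, alpha=1.0):
    """CELU activation: identity for x > 0, smooth exponential for x <= 0."""
    return np.maximum(x, 0.0) + np.minimum(0.0, alpha * (np.exp(x / alpha) - 1.0))

def atomic_energies(G, W1, b1, W2, b2):
    """Shallow network: fingerprints (N, n_svs) -> atomic energies (N,)."""
    hidden = celu(G @ W1 + b1)        # this layer mixes the SV outputs
    return (hidden @ W2 + b2).ravel()

rng = np.random.default_rng(0)
n_atoms, n_svs, n_hidden = 5, 2, 3

# Hypothetical fingerprints: column 0 ~ rho(S)_i, column 1 ~ FFG(S)_i
G = rng.normal(size=(n_atoms, n_svs))

W1, b1 = rng.normal(size=(n_svs, n_hidden)), np.zeros(n_hidden)
W2, b2 = rng.normal(size=(n_hidden, 1)), np.zeros(1)

e_atoms = atomic_energies(G, W1, b1, W2, b2)
e_struct = e_atoms.sum()              # total energy of the structure
print(e_atoms.shape, e_struct)
```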

Note that although this use of SVs is quite different from what we are used to thinking of them as, it is very similar to how descriptors for NNP-like models are currently generated. The evaluation of an SV no longer represents an atomic energy -- now it represents something like an atomic density. Instead of using a specific number of manually-constructed basis functions to construct the components of the descriptors, as is the case with NNPs, a single SV can be used. For example, instead of using 7 different gaussian basis functions shifted by 7 different atomic distances, I could use a single SV with 7 knots. The idea here is that by fitting a spline as the embedding function we're able to learn the descriptors instead of constructing them manually.
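
The sketch below shows why a single SV can stand in for a set of fixed basis functions. It assumes linear interpolation between knots as a simplified stand-in for the cubic splines used elsewhere in this notebook, and the knot range and distances are illustrative. The SV only depends on the geometry, the knot values `y` are the learnable part, and their dot product equals the learned radial function summed over the neighbors.

```python
import numpy as np

def hat_weights(r, knots):
    """Linear-interpolation weights of distance r on the knot grid."""
    weights = np.zeros(knots.size)
    i = np.searchsorted(knots, r, side='right') - 1
    i = min(max(i, 0), knots.size - 2)          # clamp to a valid interval
    t = (r - knots[i]) / (knots[i + 1] - knots[i])
    weights[i] = 1.0 - t
    weights[i + 1] = t
    return weights

knots = np.linspace(2.5, 7.0, 7)                # 7 knots, as in the example above
distances = np.array([2.6, 2.9, 3.4, 4.1])      # hypothetical neighbor distances

# Structure vector: geometry only, computed once per atomic environment
sv = sum(hat_weights(r, knots) for r in distances)

# Learnable knot values (random here; these would be fit during training)
rng = np.random.default_rng(0)
y = rng.normal(size=knots.size)

descriptor = sv @ y                             # one descriptor component
reference = np.interp(distances, knots, y).sum()
print(np.isclose(descriptor, reference))
```

Because the SV is precomputed, changing `y` during training only costs a dot product, which is the same property the existing fitting code relies on.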

## Would this actually be useful?

TLDR: I don't know, but it seems worth exploring! It would be faster than NNPs, it might be more interpretable, it would probably be at least as accurate, and we already have fast fitting code for it.

* **Speed:** MLIPs notoriously have the problem that constructing the descriptors is extremely expensive. For example, in the AL_Al ANI network I'm pretty sure that they're using 32 radial functions and 64 angular functions. This would make it around 60x slower than a single-element s-MEAM -- this estimate appears to be a bit high though, since the NNPs from the mlearn work only seem to be ~10x slower. But still, that's not a small number. I would also plan to use far fewer descriptors (2-3 each of radial/angular?), which would make the NN itself much smaller; the speed-up from the smaller NN is probably negligible though.

* **Interpretability:** it seems possible, but in no way guaranteed, that an SV would be more interpretable than the descriptors used for an NNP. In the ANI supplementary material they compare the descriptors for two structures by simply plotting all of the components. While this would conceivably still be somewhat interpretable, it would require a lot of back-and-forth checking of what each component represents. Plotting the SVs themselves wouldn't give the descriptors, but it would give insight into how the embedding of local environments is being done. This is theoretically useful, but from my experience I think it's likely that only a few of the splines would be interpretable, while others might be too "wiggly" to understand. But still, maybe better than looking at the NNP embeddings? Also, you'd be able to enforce certain physical behaviors by pinning knots (e.g. forcing convergence to 0 energy at far distances).

* **Accuracy:** I think this is the most difficult attribute to appraise. I don't see any obvious reason why a learned descriptor would be any more/less accurate than a manually-constructed descriptor using a large number of basis functions. Furthermore, I still haven't figured out how to incorporate long-range interactions into an SV. However, the use of the NN to mix the outputs of multiple SVs would almost certainly improve the accuracy of the tree-based models.

* **Fitting:** we would be able to leverage all of the advantages of fitting with SVs that we've already been using.
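
The knot pinning mentioned under interpretability can be sketched as follows. This mirrors the currently disabled branch of `Embedder._fill` further down, written in NumPy. The parameter layout `[y_0, ..., y_{K-1}, d_left, d_right]` (knot values followed by the two boundary derivatives) is my assumption based on how `plot_splines` unpacks the parameters, and `fill_pinned` is a hypothetical helper name.

```python
import numpy as np

def fill_pinned(free_params):
    """Expand nk-2 free parameters into a full length-nk spline parameter vector.

    Assumed layout of the full vector: [y_0, ..., y_{K-1}, d_left, d_right].
    The right-most knot value and the right-boundary derivative are pinned
    to zero so that the radial function goes smoothly to 0 at the cutoff.
    """
    free_params = np.asarray(free_params, dtype=float)
    return np.concatenate([
        free_params[:-1],    # free knot values y_0 ... y_{K-2}
        [0.0],               # pinned right-most knot value
        free_params[-1:],    # free left-boundary derivative
        [0.0],               # pinned right-boundary derivative
    ])

full = fill_pinned([1, 2, 3, 4, 5, 6, 7])
print(full)
```

Only the free parameters would be registered for optimization; the pinned entries are constants, so the optimizer cannot move them.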

## Are you sure that other models aren't learning their descriptors?
No, but I'm pretty sure?

* **NNP:** as I said before, NNPs and their variants define their descriptor components by choosing a certain number of shifted gaussians, then concatenating the sums of those gaussians over local environments. The learning is only on the NN weights and biases. While the descriptors could technically be considered "learnable" using hyper-parameter tuning methods, this is rarely (if ever) done.

* **GAP:** the SOAP descriptors used by GAP are kernel functions that output distance metrics that operate on atomic densities that have been written as Bessel function expansions. Yes, there's much more complicated math going on there, but I think this is still essentially true. These kernel functions produce distance metrics that are used in gaussian process regression. The only learning that is done is during the GPR.

* **MTP:** MTPs are just doing linear regression on contracted moment tensors that were constructed using basis functions that are expansions of "moment polynomials". The strength of this method is in the fact that MTP can use increasingly many moment tensors. Their paper says that the basis functions can theoretically span all permutation- and rotation-invariant polynomials. This method seems the most competitive with my idea; by expressing the energy of an atom as a linear combination of these basis functions, it seems that while it isn't necessarily learning descriptors, the descriptors that it's using are complex enough to do nearly anything. I need to understand MTP a bit better, I think; what are the drawbacks of MTP, if any?

* **SNAP:** Nope; I think this is like MTP, but with worse descriptors.

# Implementation

## Technical notes
* If I store the parameter sets, including those dotted with the SVs, as PyTorch tensors with `requires_grad=True`, I think autograd can handle the backpropagation on its own.
* This formalism can use the versions of the force SVs that are contracted along the neighbor dimension.

## Imports

In [None]:
import os
import gc
import time
import glob
import h5py
import random
import pickle
import itertools
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import CubicSpline

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

## Dataset class

In [None]:
class SVDataset(Dataset):
    def __init__(self, path, elements, svNames, refStruct):
        super(SVDataset, self).__init__()
        
        self.refStruct = refStruct
        
        with h5py.File(path, 'r') as db:
            
            self.structNames = sorted(list(db.keys()))
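            # Fall back to the first structure if the requested reference is missing
            if self.refStruct not in self.structNames:
                self.refStruct = self.structNames[0]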
            self.refIdx = self.structNames.index(self.refStruct)
            
            self.elements = elements
            self.svNames = svNames

            self.all_true_eng = []
            self.all_true_fcs = []
            self.natoms = []

            self.svs = {}
            for svn in svNames:
                self.svs[svn] = {}
                for el in elements:
                    self.svs[svn][el] = {}
                    self.svs[svn][el]['energy'] = []
                    self.svs[svn][el]['forces'] = []

            for i, k in enumerate(self.structNames):
                print(i, k)
                self.all_true_eng.append(db[k].attrs['energy'])
                self.all_true_fcs.append(db[k].attrs['forces'].ravel())
                self.natoms.append(db[k].attrs['natoms'])

                for svn in svNames:
                    for el in elements:
                        self.svs[svn][el]['energy'].append(db[k][svn][el]['energy'][()])
                        self.svs[svn][el]['forces'].append(db[k][svn][el]['forces'][()])

            self._len = len(self.all_true_eng)
            
    def __len__(self):
        return self._len
    
    def __getitem__(self, idx):
        # Collect the energy/force SVs and reference values for structure `idx`
        return {
            'svs_e': {svn: {el: self.svs[svn][el]['energy'][idx] for el in self.elements} for svn in self.svNames},
            'svs_f': {svn: {el: self.svs[svn][el]['forces'][idx] for el in self.elements} for svn in self.svNames},
            'natoms': self.natoms[idx],
            'energy': self.all_true_eng[idx],
            'forces': self.all_true_fcs[idx],
        }

## Embedder and Model classes

In [None]:
class Embedder(nn.Module):
    def __init__(self, embeddings, device):
        """
        Args:
            embeddings (dict):
                Key = embedder name
                Value = 3-tuple
                    First: list of dimensions of embedder components
                    Second: component orders for building outer products
                    Third: number of duplicates to use
                
                Example:
                    embeddings = {
                        'rho_A': ([9], [0], 2),  # Builds 2 Rho
                        'ffg_AA': ([9, 9], [0, 0, 1], 3)  # Builds 3 F*F*G
                    }
        """
        super(Embedder, self).__init__()
        
        self.embeddings = embeddings
        self._embedding_order = sorted(list(self.embeddings.keys()))
        
        # Register embedding parameters so that they get learned
        input_dim = 0
        for key, (size_list, build_order, duplicates) in self.embeddings.items():
            for dup in range(duplicates):
                # Build parameter vectors
#                 params = [nn.Parameter(torch.rand(nk)) for nk in size_list]
                params = []
                for idx, nk in enumerate(size_list):
#                     if ('rho' in key) or (('ffg' in key) and (idx != len(self.embeddings[key][0])-1)):
                    if False:
                        # Radial splines
                        params.append(nn.Parameter(torch.rand(nk-2)))
                    else:
                        params.append(nn.Parameter(torch.rand(nk)))
                        
                input_dim += 1

                # Record each parameter set (with unique key) for autograd
                for i, pset in enumerate(params):
                    bigKey = key + '_' + str(dup) + '_' + str(i)
                    setattr(self, bigKey, pset)
    
        # Used for resuming training plots
        self.steps = 0
        
        self._zero = torch.zeros(1).view((1,)).to(device)
        self._device = device
        
    def _fill(self, key, dup, idx):
        name = key + '_' + str(dup) + '_' + str(idx)
        pset = getattr(self, name)

#         if ('rho' in key) or (('ffg' in key) and (idx != len(self.embeddings[key][0])-1)):
        if False:
            # Radial splines need to have pinned RHS knot and derivative
            tmp = torch.cat(
                [pset[:-1],
                 self._zero,
                 pset[-1].view((1,)),
                 self._zero]
            ).to(self._device)
        else:
            tmp = pset
        
        return tmp, name
        
        
    def _conv(self, key, dup):
        """
        Helper function to compute the outer products of parameter sets
        
        Args:
            key (str):
                Name of embedder to build
            dup (int):
                Which duplicate to work with
        """

        build_order = self.embeddings[key][1]
        
        cart = None
        for idx in build_order:
            pset, _ = self._fill(key, dup, idx)

            if cart is None:
                cart = pset
            else:
                cart = torch.outer(cart, pset)
                cart = cart.view(cart.shape[0]*pset.shape[0])
        
        return cart.to(self._device)
    
    def embed(self, svs_fcs):
        """
        Only the force SVs are embedded for now; the energy branch is
        commented out below.

        Args:
            svs_fcs (dict):
                Key = embedder name (must match `embeddings` keys from __init__)
                Value = tensor of shape (N_tot, 3, nk_cart), where N_tot is the
                    total number of atomic environments and nk_cart is the size
                    of the structure vector for the given embedder.

        Returns:
            Tensor of shape (3*N_tot, n_descriptors), where n_descriptors is
            the total number of duplicates over all embedders.
        """
        
        # Build the atomic embeddings
#         descriptors_eng = []
        descriptors_fcs = []
        for k in self._embedding_order:
            for dup in range(self.embeddings[k][2]):
                cart = self._conv(k, dup)
#                 descriptors_eng.append(torch.matmul(svs_eng[k], cart).unsqueeze(-1))
                descriptors_fcs.append(torch.matmul(svs_fcs[k], cart).unsqueeze(-1))
            
#         descriptors_eng = torch.cat(descriptors_eng, dim=1)
        descriptors_fcs = torch.cat(descriptors_fcs, dim=2)
        
        descriptors_fcs = descriptors_fcs.view((descriptors_fcs.shape[0]*3, descriptors_fcs.shape[-1]))
        
        return descriptors_fcs#, descriptors_fcs
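
For reference, the flattened outer product built by `_conv` can be checked with a small NumPy sketch. The parameter vectors here are short, made-up values (the real ones have length 9), and `conv` is a standalone re-implementation for illustration rather than the method above.

```python
import numpy as np

def conv(param_sets, build_order):
    """Flattened outer product of the parameter sets, in the given order."""
    cart = None
    for idx in build_order:
        pset = param_sets[idx]
        cart = pset if cart is None else np.outer(cart, pset).ravel()
    return cart

f = np.array([1.0, 2.0, 3.0])     # illustrative radial parameters
g = np.array([10.0, 20.0])        # illustrative angular parameters

# build_order [0, 0, 1] corresponds to F*F*G, as in the 'ffg_AA' embedder
cart = conv([f, g], build_order=[0, 0, 1])

# Entry (i, j, k) of the product f[i]*f[j]*g[k] lands at flat index (i*3 + j)*2 + k
expected = np.einsum('i,j,k->ijk', f, f, g).ravel()
print(cart.shape, np.allclose(cart, expected))

# Dotting force SVs of shape (N_tot, 3, nk_cart) with cart gives one
# descriptor component per atom and Cartesian direction
rng = np.random.default_rng(0)
svs_fcs = rng.normal(size=(4, 3, cart.size))
descriptor = svs_fcs @ cart
print(descriptor.shape)
```

The SV size grows as the product of the component sizes (9*9*9 = 729 for the F*F*G embedder configured below), so the cost of each descriptor is one dot product of that length per atom.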

In [None]:
class Model(nn.Module):
    def __init__(self, input_dim, hidden_dim1=3, hidden_dim2=3):

        super(Model, self).__init__()
        
        self.linear1 = nn.Linear(input_dim, hidden_dim1)
        self.linear2 = nn.Linear(hidden_dim1, hidden_dim1)
        self.linear3 = nn.Linear(hidden_dim1, hidden_dim2)
        self.linear4 = nn.Linear(hidden_dim2, 1)
        
        self.bnorm1 = nn.BatchNorm1d(num_features=hidden_dim1)
        self.splus  = nn.Softplus()
        self.celu   = nn.CELU()
        
        # Used for resuming training plots
        self.steps = 0
        
    def _forward(self, x):
        """
        Helper function for passing the embeddings through the network
        """
        out = self.linear1(x)
        out = self.celu(out)
        out = self.bnorm1(out)
        out = self.linear2(out)
        out = self.celu(out)
        out = self.linear3(out)
        out = self.celu(out)
        out = self.linear4(out)
        
        return out
    
#     def forward(self, descriptors_eng, descriptors_fcs, splits):
    def forward(self, descriptors_fcs, splits):
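        """
        Args:
            descriptors_fcs (tensor):
                Force descriptors of shape (3*N_tot, input_dim), as returned by
                `Embedder.embed`.
            splits (list):
                N_s integers corresponding to the number of atoms in each of the N_s
                structures being evaluated. Used for splitting intermediate results.
                Currently unused because the energy branch is commented out.
        """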

        
        # Evaluate the model
#         atomic_eng = self._forward(descriptors_eng)
#         eng = torch.cat([eatom.sum().view(1)/n for n, eatom in zip(splits, torch.split(atomic_eng, splits))])
        
        fcs = self._forward(descriptors_fcs).squeeze()

        return fcs
#         return eng#, fcs
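
The role of `splits` in the commented-out energy branch can be illustrated with a NumPy sketch. The atomic energies are made-up numbers, and `energies_per_atom` is a hypothetical helper that mimics the `torch.split` expression above.

```python
import numpy as np

def energies_per_atom(atomic_eng, splits):
    """Sum atomic energies within each structure and divide by its atom count."""
    bounds = np.cumsum(splits)[:-1]              # where each structure ends
    chunks = np.split(atomic_eng, bounds)
    return np.array([chunk.sum() / n for n, chunk in zip(splits, chunks)])

# Two structures stacked along the atom axis: 2 atoms, then 4 atoms
atomic_eng = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
splits = [2, 4]

eng = energies_per_atom(atomic_eng, splits)
print(eng)
```

Stacking all atoms of a batch into one array lets the network run a single forward pass, and `splits` is the only bookkeeping needed to recover per-structure values.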

## Load data

In [None]:
path = os.path.join('/home', 'vita', 'AL_Al-7knots-full-allsums.hdf5')
dataset = SVDataset(path, ['Al'], ['rho_A', 'ffg_AA'], 'AL-step_14-data-003.001.012.000.000.005.h5-0')

In [None]:
max_fcs = np.array([np.max(abs(f)) for f in dataset.all_true_fcs])

In [None]:
_ = plt.hist(max_fcs, bins=100)

In [None]:
badIndices = np.where(max_fcs > 30)[0]
badIndices
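
A possible variation, sketched here with made-up force arrays: compute the list of structures to keep once, instead of removing `badIndices` from every batch inside the training loop. The threshold of 30 matches the cell above; `keepIndices` is a hypothetical name.

```python
import numpy as np

# Hypothetical flattened force arrays for three structures
forces = [
    np.array([0.1, -0.4, 0.2]),
    np.array([35.0, 1.0, -2.0]),   # contains an outlier force component
    np.array([-3.0, 0.5, 0.0]),
]

max_fcs = np.array([np.max(np.abs(f)) for f in forces])
badIndices = np.where(max_fcs > 30)[0]

# Indices of structures that are safe to train on
keepIndices = np.setdiff1d(np.arange(len(forces)), badIndices)
print(badIndices, keepIndices)
```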

## Initialize model

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
nrho = 4
nffg = 8
hidden_dim1 = 128
hidden_dim2 = 32

embeddings = {
    'rho_A': ([9], [0], nrho),
    'ffg_AA': ([9, 9], [0, 0, 1], nffg)
}

embedder = Embedder(embeddings, device)
model = Model(nrho+nffg, hidden_dim1=hidden_dim1, hidden_dim2=hidden_dim2)

In [None]:
baseName = 'AL_Al'
savePath = 'runs/{}_r_{}_f_{}_h1_{}_h2_{}_f_6352_unpin'.format(baseName, nrho, nffg, hidden_dim1, hidden_dim2)
writer = SummaryWriter(savePath)

In [None]:
f = lambda s: int(os.path.split(s)[-1].split('.')[0])
sorted(glob.glob(os.path.join(savePath, '*.model')), key=f)

In [None]:
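# NOTE: this cell expects state_dict files ('{}.embedder' and '{}.model'), as
# written by the commented-out torch.save calls in the training loop. The
# active checkpoint code there pickles whole modules into '{}.embed' and
# '{}.model' instead, which load_state_dict cannot read directly.
ii = 910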
embedder.load_state_dict(torch.load(os.path.join(savePath, '{}.embedder'.format(ii))))
embedder.eval()

model.load_state_dict(torch.load(os.path.join(savePath, '{}.model'.format(ii))))
model.eval()

In [None]:
model.to(device)
embedder.to(device)

In [None]:
model

In [None]:
loss_fxn = nn.MSELoss()

# optimizer_embed = torch.optim.SGD(embedder.parameters(), lr=1e-5, momentum=0.9)
# optimizer_model = torch.optim.SGD(model.parameters(), lr=1e-5, momentum=0.9)

optimizer_embed = torch.optim.Adam(embedder.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08)
optimizer_model = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08)

lambda_embed = lambda epoch: 0.90**(epoch//10000)
lambda_model = lambda epoch: 0.75**(epoch//10000)

scheduler_embed = torch.optim.lr_scheduler.MultiStepLR(optimizer_embed, milestones=[100000], gamma=0.1)
scheduler_model = torch.optim.lr_scheduler.MultiStepLR(optimizer_model, milestones=[100000], gamma=0.1)
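
For reference, the learning rate that a multi-step schedule produces can be written in closed form. The pure-Python sketch below is my own helper (`multistep_lr` is a hypothetical name, not a PyTorch function) and just reproduces the arithmetic for the settings above.

```python
from bisect import bisect_right

def multistep_lr(base_lr, milestones, gamma, step):
    """Learning rate after `step` scheduler steps: decay by gamma at each milestone."""
    return base_lr * gamma ** bisect_right(sorted(milestones), step)

for step in (0, 99999, 100000, 250000):
    print(step, multistep_lr(1e-3, [100000], 0.1, step))
```

Note that the training cell below calls `scheduler.step()` once per batch, so the milestone of 100000 counts batches rather than epochs.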

### Helper functions

In [None]:
# Define plotting functions

def plot_pred_vs_true(eng, fcs, all_true_eng, all_true_fcs):
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
#     fig, ax = plt.subplots(figsize=(6, 6))

    # Energy errors figure
    xl, xh = min(all_true_eng), max(all_true_eng)

    rmse = np.sqrt(np.average((eng - all_true_eng)**2))
    ax[0].plot(all_true_eng, eng, 'o', markersize=1, label='RMSE: {:.3f} eV/atom'.format(rmse), alpha=1)
    ax[0].plot([xl, xh], [xl, xh], '--r')

    lgnd = ax[0].legend(loc = 'upper left')
    lgnd.legendHandles[0]._legmarker.set_markersize(6)
    lgnd.legendHandles[0]._legmarker.set_alpha(1)

    ax[0].set_xlim([xl, xh])
    ax[0].set_ylim([xl, xh])

    ax[0].set_aspect('equal')

    ax[0].set_xlabel('True', fontsize=12)
    ax[0].set_ylabel("Predicted", fontsize=12)

    _ = ax[0].set_title('Energies')
    
    # Forces errors figure
    xl, xh = min(all_true_fcs), max(all_true_fcs)

    rmse = np.sqrt(np.average((fcs - all_true_fcs)**2))
    ax[1].plot(all_true_fcs, fcs, 'o', markersize=1, label='RMSE: {:.3f} eV/A'.format(rmse), alpha=0.05)
    ax[1].plot([xl, xh], [xl, xh], '--r')

    lgnd = ax[1].legend(loc = 'upper left')
    lgnd.legendHandles[0]._legmarker.set_markersize(6)
    lgnd.legendHandles[0]._legmarker.set_alpha(1)

    ax[1].set_xlim([xl, xh])
    ax[1].set_ylim([xl, xh])

    ax[1].set_aspect('equal')

    ax[1].set_xlabel('True', fontsize=12)
    ax[1].set_ylabel("Predicted", fontsize=12)

    _ = ax[1].set_title('Forces')
    
#     # Forces errors figure
#     xl, xh = min(all_true_fcs), max(all_true_fcs)

#     rmse = np.sqrt(np.average((fcs - all_true_fcs)**2))
#     ax.plot(all_true_fcs, fcs, 'o', markersize=1, label='RMSE: {:.3f} eV/A'.format(rmse), alpha=0.05)
#     ax.plot([xl, xh], [xl, xh], '--r')

#     lgnd = ax.legend(loc = 'upper left')
#     lgnd.legendHandles[0]._legmarker.set_markersize(6)
#     lgnd.legendHandles[0]._legmarker.set_alpha(1)

#     ax.set_xlim([xl, xh])
#     ax.set_ylim([xl, xh])

#     ax.set_aspect('equal')

#     ax.set_xlabel('True', fontsize=12)
#     ax.set_ylabel("Predicted", fontsize=12)

#     _ = ax.set_title('Forces')
#     plt.tight_layout()
    
    return fig

def plot_splines(splitParams, names):
    numRows = int(max(1, np.ceil(len(splitParams)/3)))

    fig, axes = plt.subplots(numRows, 3, figsize=(12, 4*numRows))

    for i, (spline, name) in enumerate(zip(splitParams, names)):
        y, bc = spline[:-2], spline[-2:]

    #     if 'g' in compNames[i]:
    #         x = np.linspace(-1, 1, len(y))
    #     else:
    #         x = np.linspace(2.5, 7.0, len(y))

        x = np.linspace(2.5, 7.0, len(y))
        
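        # Clamped boundary conditions (using the two trailing parameters) are
        # disabled here; the plot uses natural boundary conditions instead
        # cs = CubicSpline(x, y, bc_type=((1, bc[0]), (1, bc[1])))
        cs = CubicSpline(x, y, bc_type='natural')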

        plotX = np.linspace(x[0]-.1, x[-1]+.1, 100)
        plotY = cs(plotX)
        
        row = i//3
        col = i%3

        if numRows > 1:
            ax = axes[row][col]
        else:
            ax = axes[col]

        ax.plot(x, y, 'o')
        ax.plot(plotX, plotY)
        ax.set_title(name)
    
    return fig

## Begin training

In [None]:
ns = len(dataset)

epochs = 1000000
batch_size = 1024

indices = np.arange(ns)
start = time.time()
for t in range(epochs):
    # Build random batches
    random.shuffle(indices)
    batches = np.array_split(indices, max(1, ns//batch_size))  # at least one batch
    
    epoch_eng = []
    epoch_fcs = []
    epoch_true_eng = []
    epoch_true_fcs = []

    # Evaluate all batches
    epoch_loss = 0
    for batch in batches:
        batch = batch.tolist()
        
        for bi in badIndices:
            if bi in batch:
                batch.remove(bi)

        if dataset.refIdx not in batch:
            batch.append(dataset.refIdx)

#         svs_eng = {
#             'rho_A': torch.Tensor(np.vstack([dataset.svs['rho_A']['Al']['energy'][i] for i in batch])).to(device),
#             'ffg_AA': torch.Tensor(np.vstack([dataset.svs['ffg_AA']['Al']['energy'][i] for i in batch])).to(device),
#         }

        svs_fcs = {
            'rho_A': torch.Tensor(np.concatenate([dataset.svs['rho_A']['Al']['forces'][i] for i in batch])).to(device),
            'ffg_AA': torch.Tensor(np.concatenate([dataset.svs['ffg_AA']['Al']['forces'][i] for i in batch])).to(device),
        }

#         all_true_eng = torch.Tensor([dataset.all_true_eng[i] for i in batch]).to(device)
        all_true_fcs = torch.Tensor(np.concatenate([dataset.all_true_fcs[i] for i in batch])).to(device)
        splits = [dataset.natoms[i] for i in batch]

        # Compute loss
#         descriptors_eng, descriptors_fcs = embedder.embed(svs_eng, svs_fcs)
        descriptors_fcs = embedder.embed(svs_fcs)
#         eng, fcs = model(descriptors_eng, descriptors_fcs, splits)
        fcs = model(descriptors_fcs, splits)
#         eng = eng - eng[batch.index(dataset.refIdx)]
        
#         epoch_eng.append(np.array(eng.cpu().detach()))
        epoch_fcs.append(np.array(fcs.cpu().detach()))
#         epoch_true_eng.append(np.array(all_true_eng.cpu().detach()))
        epoch_true_fcs.append(np.array(all_true_fcs.cpu().detach()))

#         loss_eng = loss_fxn(eng, all_true_eng)*10
        loss_fcs = loss_fxn(fcs, all_true_fcs)

#         loss = loss_eng + loss_fcs
        loss = loss_fcs

        # Backpropagate errors (the graph is rebuilt every batch, so it does not need to be retained)
        optimizer_model.zero_grad()
        optimizer_embed.zero_grad()
        loss.backward()
        optimizer_embed.step()
        optimizer_model.step()
        
        scheduler_embed.step()
        scheduler_model.step()

        batch_loss = loss.detach().item()
        epoch_loss += batch_loss*(len(batch)/ns)  # weighted average
        
#         recheck_loss = np.average(
#             (np.array(fcs.cpu().detach()) - np.array(all_true_fcs.cpu().detach()))**2
#         )
        
#         print("Batch loss: {:.6f} -- {:.2f} (s)".format(batch_loss, recheck_loss))
        
    # Print status
    print("Epoch {} loss (avg. batch loss): {:.6f} -- {:.2f} (s)".format(t, epoch_loss, time.time()-start))
#     break
    model.steps += 1
    embedder.steps += 1
    
    writer.add_scalar('Training loss', epoch_loss, model.steps)

    # Log results and do garbage collection
    if (t+1)%10 == 0:
#         torch.save(embedder.state_dict(), os.path.join(savePath, '{}.embedder'.format(t+1)))
#         torch.save(model.state_dict(), os.path.join(savePath, '{}.model'.format(t+1)))
        
        checkpoint_model = { 
            'model': model,
            'optimizer': optimizer_model.state_dict()
        }

        checkpoint_embed = { 
            'model': embedder,
            'optimizer': optimizer_embed.state_dict()
        }

        torch.save(checkpoint_model, os.path.join(savePath, '{}.model'.format(t+1)))
        torch.save(checkpoint_embed, os.path.join(savePath, '{}.embed'.format(t+1)))
        
#         eng_np = np.array(eng.cpu().detach())
#         fcs_np = np.array(fcs.cpu().detach())
#         true_eng_np = np.array(all_true_eng.cpu().detach())
#         true_fcs_np = np.array(all_true_fcs.cpu().detach())
        
#         eng_np = np.concatenate(epoch_eng)
        fcs_np = np.concatenate(epoch_fcs)
#         true_eng_np = np.concatenate(epoch_true_eng)
        true_fcs_np = np.concatenate(epoch_true_fcs)
        eng_np = np.zeros(1)
        true_eng_np = np.zeros(1)
        
        writer.add_figure(
            'Predictions vs. True',
            plot_pred_vs_true(
                eng_np,
                fcs_np,
                true_eng_np,
                true_fcs_np
            ),
            global_step=model.steps
        )
        
        splitParams = []
        names = []
        for k in embedder.embeddings:
            for d in range(embedder.embeddings[k][2]):
                for i in range(len(embedder.embeddings[k][0])):
                    s, n = embedder._fill(k, d, i)
                    splitParams.append(np.array(s.cpu().detach()))
                    names.append(n)
        
        writer.add_figure(
            'Splines',
            plot_splines(splitParams, names),
            global_step=embedder.steps
        )
        
#         del descriptors_eng
        del descriptors_fcs
        
#         del eng
        del fcs
        
#         del loss_eng
        del loss_fcs
        del loss
        
        del eng_np
        del fcs_np
        del true_eng_np
        del true_fcs_np
        
        del splitParams
        
        gc.collect()
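
The batch construction at the top of the loop can also be sketched in isolation. This is a pure-Python variation, not the exact logic above: it filters the bad indices once before shuffling instead of removing them from each batch, and `build_batches` with its arguments is a hypothetical helper.

```python
import random

def build_batches(n_structs, batch_size, bad_indices, ref_idx, seed=None):
    """Shuffle the usable structure indices and split them into batches.

    Every batch also contains the reference structure, as in the loop above.
    """
    rng = random.Random(seed)
    bad = set(bad_indices)
    indices = [i for i in range(n_structs) if i not in bad]
    rng.shuffle(indices)

    n_batches = max(1, len(indices) // batch_size)   # never zero batches
    # Strided slices of a shuffled list give near-equal random batches
    batches = [indices[k::n_batches] for k in range(n_batches)]

    for batch in batches:
        if ref_idx not in batch:
            batch.append(ref_idx)
    return batches

batches = build_batches(10, 3, bad_indices=[4], ref_idx=0, seed=1)
for batch in batches:
    print(sorted(batch))
```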

In [None]:
checkpoint_model = { 
    'model': model,
    'optimizer': optimizer_model.state_dict()
}

checkpoint_embed = { 
    'model': embedder,
    'optimizer': optimizer_embed.state_dict()
}

torch.save(checkpoint_model, os.path.join(savePath, '{}.model'.format(t+1)))
torch.save(checkpoint_embed, os.path.join(savePath, '{}.embed'.format(t+1)))

In [None]:
#         del descriptors_eng
# del descriptors_fcs

# #         del eng
# del fcs

# #         del loss_eng
# del loss_fcs
# del loss

# del eng_np
# del fcs_np
# del true_eng_np
# del true_fcs_np

# del splitParams
torch.cuda.empty_cache()

In [None]:
# for g in optimizer_embed.param_groups:
for ge, gm in zip(optimizer_embed.param_groups, optimizer_model.param_groups):
    print(ge['lr'], gm['lr'])
    ge['lr'] = 1e-4
    gm['lr'] = 1e-4

In [None]:
fcs_np = np.concatenate(epoch_fcs)
true_fcs_np = np.concatenate(epoch_true_fcs)
eng_np = np.zeros(1)
true_eng_np = np.zeros(1)

plot_pred_vs_true(
    eng_np,
    fcs_np,
    true_eng_np,
    true_fcs_np
)

In [None]:
2.199**2

In [None]:
splitParams = []
names = []
for k in embedder.embeddings:
    for d in range(embedder.embeddings[k][2]):
        for i in range(len(embedder.embeddings[k][0])):
            s, n = embedder._fill(k, d, i)
            splitParams.append(np.array(s.cpu().detach()))
            names.append(n)
            
plot_splines(splitParams, names)