In [9]:
import torch
from torch import tensor
import numpy as np
from functools import partial
from opt_einsum import contract

In [26]:
def ae_loss(ref_dict,pred_dict, loss, **kwargs):
    """ae_loss(ref_dict, pred_dict, loss, **kwargs):
    
        Calulates atomization energy loss from reference values.

    Args:
        ref_dict ([dict]): A dictionary of reference atomization energies, whose values are flattened to a list.
        pred_dict ([dict]): A dictionary of predicted atomization energies, whose values are flattened to a list.
        loss (callable): Callable loss function
        weights (torch.Tensor) [optional]: if specified, scale individual energy differences.
            defaults to a linspace of weights from 0 to 1 of size results['E'], or 1 if only one prediction.

    Returns:
        [?]: loss called on weighted difference between reference and prediction
    """
    print("AE_LOSS FUNCTION")
    print("INPUT REF/PRED: ")
    print("REF: {}".format(ref_dict))
    print("PRED: {}".format(pred_dict))
    print("Flattening ref_dict, pred_dict")
    #ref = torch.cat(list(atomization_energies(ref_dict).values()))
    atm_pred = atomization_energies(pred_dict)
    ref = ref_dict[list(atm_pred.keys())[0]]
    pred = torch.cat(list(atm_pred.values()))
    assert len(ref) == 1
    ref = ref.expand(pred.size()[0])
    if pred.size()[0] > 1:
        weights = kwargs.get('weights', torch.linspace(0,1,pred.size()[0])**2).to(pred.device)
    else:
        weights = 1
    lae = loss((ref-pred)*weights,torch.zeros_like(pred))
    print("AE LOSS: {}".format(lae))
    return lae


def atomization_energies(energies):
    """Calculates atomization energies based on a dictionary of molecule/atomic energies.
    
    energies['ABCD'] = molecular energy
    energies['A'], energies['B'], etc. = atomic energy.
    
    Loops over ABCD - A - B - C - D

    Args:
        energies (dict): dictionary of molecule and constituent atomic energies.
    """
    def split(el):
        """Regex split molecule's symbolic expansion into constituent elements.
        No numbers must be present -- CH2 = CHH.

        Args:
            el (str): Molecule symbols

        Returns:
            list: list of individual atoms in molecule
        """
        import re
        res_list = [s for s in re.split("([A-Z][^A-Z]*)", el) if s]
        return res_list


    ae = {}
    for key in energies:
        if isinstance(energies[key],torch.Tensor):
            #if len(split(key)) == 1:continue
            e_tot = torch.clone(energies[key])
            e_tot_size = e_tot.size()
        else:
            e_tot = np.array(energies[key])
            e_tot_size = e_tot.shape
        for symbol in split(key):
            #if single atom, continue
            if len(split(key)) == 1: continue
            e_sub = energies[symbol]
            e_sub_size = e_sub.size() if isinstance(e_sub, torch.Tensor) else e_sub.shape
            if e_tot_size == e_sub_size:
                e_tot -= e_sub
            else:
                e_tot -= e_sub[-1:]
            print('{} - {}: {}'.format(key, symbol, e_tot))
            ae[key] = e_tot
    if ae == {}:
        #empty dict -- no splitting occurred, so single atom
        ae[key] = e_tot
    print("Atomization Energy Final")
    print(ae)
    return ae

In [32]:
rd = {'CCHH': tensor([-0.6459]), 'C': tensor([-37.8405]), 'H': tensor([-0.5000])}
pd = {'CCHH': tensor([-77.1652, -77.1651, -77.1651, -77.1650, -77.1650, -77.1650, -77.1650,
        -77.1650, -77.1650, -77.1650]), 'C': tensor([-37.7855, -37.7856, -37.7857, -37.7857, -37.7857, -37.7857, -37.7858,
        -37.7858, -37.7858, -37.7858]), 'H': tensor([-0.4965, -0.4966, -0.4966, -0.4967, -0.4967, -0.4967, -0.4968, -0.4968,
        -0.4968, -0.4968])}
atomization_energies(pd)

CCHH - C: tensor([-39.3797, -39.3795, -39.3794, -39.3793, -39.3793, -39.3793, -39.3792,
        -39.3792, -39.3792, -39.3792])
CCHH - C: tensor([-1.5942, -1.5939, -1.5937, -1.5936, -1.5936, -1.5936, -1.5934, -1.5934,
        -1.5934, -1.5934])
CCHH - H: tensor([-1.0977, -1.0973, -1.0971, -1.0969, -1.0969, -1.0969, -1.0966, -1.0966,
        -1.0966, -1.0966])
CCHH - H: tensor([-0.6012, -0.6007, -0.6005, -0.6002, -0.6002, -0.6002, -0.5998, -0.5998,
        -0.5998, -0.5998])
Atomization Energy Final
{'CCHH': tensor([-0.6012, -0.6007, -0.6005, -0.6002, -0.6002, -0.6002, -0.5998, -0.5998,
        -0.5998, -0.5998])}


{'CCHH': tensor([-0.6012, -0.6007, -0.6005, -0.6002, -0.6002, -0.6002, -0.5998, -0.5998,
         -0.5998, -0.5998])}

In [28]:
ae_loss = partial(ae_loss,loss = torch.nn.MSELoss())

In [29]:
ae_loss(rd, pd)

AE_LOSS FUNCTION
INPUT REF/PRED: 
REF: {'CCHH': tensor([-0.6459]), 'C': tensor([-37.8405]), 'H': tensor([-0.5000])}
PRED: {'CCHH': tensor([-77.1652, -77.1651, -77.1651, -77.1650, -77.1650, -77.1650, -77.1650,
        -77.1650, -77.1650, -77.1650]), 'C': tensor([-37.7855, -37.7856, -37.7857, -37.7857, -37.7857, -37.7857, -37.7858,
        -37.7858, -37.7858, -37.7858]), 'H': tensor([-0.4965, -0.4966, -0.4966, -0.4967, -0.4967, -0.4967, -0.4968, -0.4968,
        -0.4968, -0.4968])}
Flattening ref_dict, pred_dict
CCHH - C: tensor([-39.3797, -39.3795, -39.3794, -39.3793, -39.3793, -39.3793, -39.3792,
        -39.3792, -39.3792, -39.3792])
CCHH - C: tensor([-1.5942, -1.5939, -1.5937, -1.5936, -1.5936, -1.5936, -1.5934, -1.5934,
        -1.5934, -1.5934])
CCHH - H: tensor([-1.0977, -1.0973, -1.0971, -1.0969, -1.0969, -1.0969, -1.0966, -1.0966,
        -1.0966, -1.0966])
CCHH - H: tensor([-0.6012, -0.6007, -0.6005, -0.6002, -0.6002, -0.6002, -0.5998, -0.5998,
        -0.5998, -0.5998])
Atomiz

tensor(0.0005)

In [35]:
rd = {'NOO': tensor([-0.3634]), 'N': tensor([-54.5892]), 'O': tensor([-75.0673])}
pd = {'NOO': tensor([-204.8356]), 'N': tensor([-54.5133, -54.5134, -54.5136, -54.5137, -54.5137, -54.5138, -54.5138,
        -54.5138, -54.5138, -54.5138]), 'O': tensor([-74.9761, -74.9762, -74.9762, -74.9763, -74.9763, -74.9764, -74.9764,
        -74.9764, -74.9764, -74.9764])}
atomization_energies(pd)
ae_loss(rd, pd)

NOO - N: tensor([-150.3218])
NOO - O: tensor([-75.3454])
NOO - O: tensor([-0.3690])
Atomization Energy Final
{'NOO': tensor([-0.3690])}
AE_LOSS FUNCTION
INPUT REF/PRED: 
REF: {'NOO': tensor([-0.3634]), 'N': tensor([-54.5892]), 'O': tensor([-75.0673])}
PRED: {'NOO': tensor([-204.8356]), 'N': tensor([-54.5133, -54.5134, -54.5136, -54.5137, -54.5137, -54.5138, -54.5138,
        -54.5138, -54.5138, -54.5138]), 'O': tensor([-74.9761, -74.9762, -74.9762, -74.9763, -74.9763, -74.9764, -74.9764,
        -74.9764, -74.9764, -74.9764])}
Flattening ref_dict, pred_dict
NOO - N: tensor([-150.3218])
NOO - O: tensor([-75.3454])
NOO - O: tensor([-0.3690])
Atomization Energy Final
{'NOO': tensor([-0.3690])}
AE LOSS: 3.139678665320389e-05


tensor(3.1397e-05)

In [36]:
rd = {'OHH': tensor([-0.3713]), 'O': tensor([-75.0673]), 'H': tensor([-0.5000])}
pd = {'OHH': tensor([-76.3291]), 'O': tensor([-74.9773, -74.9774, -74.9775, -74.9776, -74.9776, -74.9776, -74.9776,
        -74.9776, -74.9777, -74.9777]), 'H': tensor([-0.4966, -0.4967, -0.4967, -0.4968, -0.4968, -0.4969, -0.4969, -0.4969,
        -0.4969, -0.4969])}
atomization_energies(pd)
ae_loss(rd, pd)

OHH - O: tensor([-1.3514])
OHH - H: tensor([-0.8545])
OHH - H: tensor([-0.3576])
Atomization Energy Final
{'OHH': tensor([-0.3576])}
AE_LOSS FUNCTION
INPUT REF/PRED: 
REF: {'OHH': tensor([-0.3713]), 'O': tensor([-75.0673]), 'H': tensor([-0.5000])}
PRED: {'OHH': tensor([-76.3291]), 'O': tensor([-74.9773, -74.9774, -74.9775, -74.9776, -74.9776, -74.9776, -74.9776,
        -74.9776, -74.9777, -74.9777]), 'H': tensor([-0.4966, -0.4967, -0.4967, -0.4968, -0.4968, -0.4969, -0.4969, -0.4969,
        -0.4969, -0.4969])}
Flattening ref_dict, pred_dict
OHH - O: tensor([-1.3514])
OHH - H: tensor([-0.8545])
OHH - H: tensor([-0.3576])
Atomization Energy Final
{'OHH': tensor([-0.3576])}
AE LOSS: 0.00018762654508464038


tensor(0.0002)

In [15]:
t = atomization_energies(pd)

CCHH - C: tensor([-39.3797, -39.3795, -39.3794, -39.3793, -39.3793, -39.3793, -39.3792,
        -39.3792, -39.3792, -39.3792])
CCHH - C: tensor([-1.5942, -1.5939, -1.5937, -1.5936, -1.5936, -1.5936, -1.5934, -1.5934,
        -1.5934, -1.5934])
CCHH - H: tensor([-1.0977, -1.0973, -1.0971, -1.0969, -1.0969, -1.0969, -1.0966, -1.0966,
        -1.0966, -1.0966])
CCHH - H: tensor([-0.6012, -0.6007, -0.6005, -0.6002, -0.6002, -0.6002, -0.5998, -0.5998,
        -0.5998, -0.5998])
Atomization Energy Final
{'CCHH': tensor([-0.6012, -0.6007, -0.6005, -0.6002, -0.6002, -0.6002, -0.5998, -0.5998,
        -0.5998, -0.5998])}


In [21]:
t.keys()

AttributeError: 'dict_keys' object has no attribute 'values'