In [203]:
from typing import Dict
from pathlib import Path
import pandas as pd
import numpy as np

In [18]:
# Load the target-property table; the first CSV column becomes the row index.
data = pd.read_csv(
    Path('../datasets/processed/dichalcogenides_x1s6_202109_MoS2') / 'targets.csv.gz',
    index_col=0,
)

In [19]:
# Bare last expression: renders the loaded targets frame as this cell's output.
data

Unnamed: 0_level_0,energy,energy_per_atom,formation_energy,formation_energy_per_site,band_gap,homo,lumo,fermi_level
_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
6141cf0efbfd4bd9ab2c2f7e,-1391.3404,-7.284505,2.6457,2.645700,1.1452,-0.6754,0.4698,-0.199707
6141cf0f51c1cbd9654b8870,-1384.5528,-7.287120,5.3063,2.653150,1.0843,-0.6852,0.3991,-0.220627
6141cf0fe689ecc4c43cdd4b,-1397.1961,-7.277063,0.2790,0.279000,1.8033,-0.6931,1.1102,-0.183537
6141cf10b842c2e72e2f2d44,-1396.2576,-7.272175,0.5795,0.289750,1.8095,-0.6916,1.1179,-0.179802
6141cf1051c1cbd9654b8872,-1384.5327,-7.287014,5.3264,2.663200,1.1102,-0.6718,0.4384,-0.213190
...,...,...,...,...,...,...,...,...
6148f3f63ac25c70a5c6cdff,-1366.4702,-7.230001,12.4557,4.151900,0.3526,-0.5351,-0.1825,-0.359015
6149087231cf3ef3d4a9f848,-1372.5659,-7.224031,9.8490,3.283000,0.3002,-0.4501,-0.1499,-0.300181
6149c48031cf3ef3d4a9f84a,-1372.2370,-7.222300,10.1779,3.392633,0.3594,-0.5045,-0.1451,-0.324836
6149f3853ac25c70a5c6ce01,-1367.4786,-7.235337,11.4473,3.815767,0.5270,-0.6883,-0.1613,-0.424306


In [45]:
def energy_within_threshold(prediction, target, e_thresh=0.02):
    """Return the fraction of systems whose absolute error is below ``e_thresh``.

    Args:
        prediction: Predicted values, one per system. Must support elementwise
            subtraction with ``target`` (e.g. NumPy arrays or pandas Series).
        target: Reference values; ``target.shape[0]`` is used as the number of
            systems.
        e_thresh: Error tolerance, in the same units as the inputs. Defaults to
            0.02, the value that used to be hard-coded in this function.

    Returns:
        ``success / total``: a fraction (not a percentage) of systems with
        ``|target - prediction| < e_thresh``.
    """
    # Absolute error per system.
    error_energy = np.abs(target - prediction)

    # Strict inequality: an error exactly equal to the threshold counts as a miss.
    success = np.sum(error_energy < e_thresh)
    total = target.shape[0]
    return success / total


In [244]:
class ResultTable:
    """Collect MAE and energy-within-threshold (EwT) scores per model and render them.

    On construction the prediction CSVs under ``../datasets/predictions`` are
    read, scored against the target table, and stored in two nested dicts:

    * ``self.MAE[model_label][property_label]`` - mean absolute error
    * ``self.EwT[model_label][property_label]`` - value returned by
      ``energy_within_threshold``

    ``print_table`` turns those dicts into a two-level-column DataFrame and
    prints its LaTeX source.
    """

    # Entries processed by ``get_results``. The first three are directory names
    # under the predictions folder; 'sparse' is a pseudo-entry that reads the
    # ``*sparse*.csv.gz`` file stored inside the 'MoS2-plain-cv' directory.
    _ITEMS = ('MoS2-plain-cv-catboost', 'MoS2-plain-cv-gemnet', 'MoS2-plain-cv', 'sparse')
    # Target columns that are scored (also the sub-directory names).
    _ENERGIES = ('homo', 'band_gap', 'formation_energy_per_site')

    def __init__(self, targets=None):
        """Load and score all predictions.

        Args:
            targets: DataFrame holding the reference values, with one column per
                scored property. When omitted, the notebook-level ``data`` frame
                is used, so existing ``ResultTable()`` calls keep working.
        """
        self.predictions = Path('../datasets/predictions').resolve()
        self.targets = data if targets is None else targets

        self.EwT = {}
        self.MAE = {}
        # Attribute name (sic) is kept unchanged in case other cells read it.
        self.prety_names = {
            'homo': 'HOMO',
            'lumo': 'LUMO',
            'band_gap': 'Bandgap',
            'formation_energy_per_site': 'Formation',
        }

        self.get_results()

    def update_dict(self, name, energy_type, this_ewt: float, this_mae: float, ewt: Dict, mae: Dict):
        """Record one (model, property) score pair.

        ``ewt`` and ``mae`` are the per-model accumulator dicts; they are updated
        in place and then (re)registered under ``name`` in ``self.EwT`` /
        ``self.MAE``.
        """
        ewt[energy_type] = this_ewt
        mae[energy_type] = this_mae
        self.EwT[name] = ewt
        self.MAE[name] = mae

    def filter_apply(self, item, energy_type, this_ewt, this_mae, ewt: Dict, mae: Dict):
        """Map a raw ``item`` name to its display label(s) and store the scores.

        'catboost' and 'gemnet' are matched as substrings; 'MoS2-plain-cv' and
        'sparse' must match exactly. The checks are independent, so an item that
        satisfies several of them is stored under every matching label.
        """
        labels = []
        if 'catboost' in item:
            labels.append("Catboost+matminer")
        if 'gemnet' in item:
            labels.append("GEMNet")
        if 'MoS2-plain-cv' == item:
            labels.append("MEGNet-full")
        if 'sparse' == item:
            labels.append("MEGNet-sparse")

        for name in labels:
            self.update_dict(name, energy_type, this_ewt, this_mae, ewt, mae)

    @staticmethod
    def _prediction_file(directory, sparse=False):
        """Pick the prediction file to read from ``directory`` deterministically.

        ``Path.iterdir`` yields entries in arbitrary order, so taking its first
        element is not reproducible, and the 'MoS2-plain-cv' directories hold the
        sparse predictions next to the full ones. Candidates are therefore
        restricted to regular files and sorted by path. For the non-sparse case,
        files without 'sparse' in their name are preferred, falling back to all
        files only if there are none.

        Raises:
            FileNotFoundError: if ``directory`` contains no candidate file.
        """
        if sparse:
            candidates = sorted(p for p in directory.glob('*sparse*.csv.gz') if p.is_file())
        else:
            files = sorted(p for p in directory.iterdir() if p.is_file())
            candidates = [p for p in files if 'sparse' not in p.name] or files

        if not candidates:
            raise FileNotFoundError(f'no prediction file found in {directory}')
        return candidates[0]

    def get_results(self):
        """Score every (item, property) pair and fill ``self.EwT`` / ``self.MAE``.

        For each pair one CSV is read (first column as index), the matching
        target column is attached as ``target`` (aligned on the index by
        ``DataFrame.assign``), and the ``predicted_<property>_test`` column is
        compared against it.
        """
        for item in self._ITEMS:
            # Fresh accumulators per item; filter_apply registers them by label.
            ewt = {}
            mae = {}
            for e in self._ENERGIES:
                if item == 'sparse':
                    path = self._prediction_file(self.predictions.joinpath('MoS2-plain-cv', e), sparse=True)
                else:
                    path = self._prediction_file(self.predictions.joinpath(item, e))

                df_pred = pd.read_csv(path, index_col=0).assign(target=self.targets[e])
                predicted = df_pred[f'predicted_{e}_test']

                this_ewt = energy_within_threshold(predicted, df_pred['target'])
                this_mae = np.abs(predicted - df_pred['target']).mean()

                self.filter_apply(item, self.prety_names[e], this_ewt, this_mae, ewt, mae)

    def _build_table(self):
        """Assemble the MAE and EwT scores into one frame with two-level columns."""
        mae = pd.DataFrame.from_dict(self.MAE, orient='index')
        ewt = pd.DataFrame.from_dict(self.EwT, orient='index')

        table = pd.concat([mae, ewt], axis=1)
        # Group labels are repeated per actual column instead of a fixed count,
        # so the table stays valid if the list of scored properties changes.
        # NOTE(review): the EwT header says percent, but the stored values are
        # whatever energy_within_threshold returned - confirm whether the label
        # or the values should change.
        group_labels = (
            [r'MAE (eV) $\downarrow$'] * mae.shape[1]
            + [r'EwT (\%) $\uparrow$'] * ewt.shape[1]
        )
        table.columns = pd.MultiIndex.from_arrays([group_labels, table.columns])
        return table

    def print_table(self):
        """Print the results table as LaTeX and return it as a DataFrame."""
        table = self._build_table()
        # One 'l' for the index column plus one per data column.
        column_format = 'l' * (table.shape[1] + 1)
        print(table.to_latex(escape=False, multicolumn_format='c', column_format=column_format))
        return table

# Build the scores, print the LaTeX source, and show the returned DataFrame as the cell output.
ResultTable().print_table()

\begin{tabular}{lllllll}
\toprule
{} & \multicolumn{3}{c}{MAE (eV) $\downarrow$} & \multicolumn{3}{c}{EwT (\%) $\uparrow$} \\
{} &                  HOMO &   Bandgap & Formation &                HOMO &   Bandgap & Formation \\
\midrule
Catboost+matminer &              0.006814 &  0.010369 &  0.007677 &            0.894488 &  0.832968 &  0.946233 \\
GEMNet            &              0.018737 &  0.065087 &  0.023933 &            0.653632 &  0.445306 &  0.827575 \\
MEGNet-full       &              0.007566 &  0.050782 &  0.055232 &            0.892466 &  0.498736 &  0.337266 \\
MEGNet-sparse     &              0.004868 &  0.007194 &  0.007105 &            0.950615 &  0.917411 &  0.953312 \\
\bottomrule
\end{tabular}



Unnamed: 0_level_0,MAE (eV) $\downarrow$,MAE (eV) $\downarrow$,MAE (eV) $\downarrow$,EwT (\%) $\uparrow$,EwT (\%) $\uparrow$,EwT (\%) $\uparrow$
Unnamed: 0_level_1,HOMO,Bandgap,Formation,HOMO,Bandgap,Formation
Catboost+matminer,0.006814,0.010369,0.007677,0.894488,0.832968,0.946233
GEMNet,0.018737,0.065087,0.023933,0.653632,0.445306,0.827575
MEGNet-full,0.007566,0.050782,0.055232,0.892466,0.498736,0.337266
MEGNet-sparse,0.004868,0.007194,0.007105,0.950615,0.917411,0.953312
