**This notebook shows how to re-verify the performance of the GCNNs on the validation set (function: reverify_wandb_models), select the top-performing models accordingly (function: keep_the_best_few_models), compute predictions on the test and holdout sets (function: get_all_model_predictions), and extract the latent embeddings of CGCNN and e3nn after all message-passing and graph-convolution layers (function: get_all_embeddings).**

Parameters:
- struct_type: the structure representation to use (options: unrelaxed, relaxed, M3Gnet_relaxed)
- model_type: the model architecture to use (options: CGCNN, e3nn, Painn)
- gpu_num: the GPU to use
- training_fraction: the fraction of the training set used for training (1.0 if the model was trained on the entire training set)
- num_best_models: the number of top-performing models to use
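
As a minimal sketch of how these options fit together, the helper below builds a `model_params` dictionary and rejects values outside the options listed above. The helper name `make_model_params` and the two `ALLOWED_*` constants are hypothetical and not part of the `inference` package; the `(0, 1]` range check on `training_fraction` is also an assumption.

```python
# Hypothetical helper, not part of the inference package.
ALLOWED_STRUCT_TYPES = {"unrelaxed", "relaxed", "M3Gnet_relaxed"}
ALLOWED_MODEL_TYPES = {"CGCNN", "e3nn", "Painn"}


def make_model_params(struct_type, model_type, training_fraction=1.0):
    """Build a model_params dict after checking it against the listed options."""
    if struct_type not in ALLOWED_STRUCT_TYPES:
        raise ValueError(f"unknown struct_type: {struct_type!r}")
    if model_type not in ALLOWED_MODEL_TYPES:
        raise ValueError(f"unknown model_type: {model_type!r}")
    # Assumed valid range: a positive fraction no larger than 1.
    if not 0.0 < training_fraction <= 1.0:
        raise ValueError("training_fraction must be in (0, 1]")
    return {
        "struct_type": struct_type,
        "model_type": model_type,
        "training_fraction": training_fraction,
    }


params = make_model_params("unrelaxed", "CGCNN")
```

The resulting dictionary has the same shape as the `model_params` argument passed to the functions in the cells below.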

In [1]:
from inference.select_best_models import reverify_wandb_models, keep_the_best_few_models
from inference.test_model_prediction import get_all_model_predictions
from inference.embedding_extraction import get_all_embeddings


# CGCNN

In [2]:
reverify_wandb_models(
    model_params={
        "struct_type": "unrelaxed",
        "model_type": "CGCNN",
        "training_fraction":1.0,
    },
    gpu_num=0
)

Loaded data
Completed data processing


100%|██████████| 6276/6276 [00:00<00:00, 8605.95it/s]
100%|██████████| 1277/1277 [00:00<00:00, 8486.59it/s]


Reverifying wandb model #0 (observ_09b2uh9s)


100%|██████████| 6276/6276 [01:27<00:00, 71.89it/s] 


Reverifying wandb model #1 (observ_2bwzbm4a)


100%|██████████| 6276/6276 [00:00<00:00, 28012.25it/s]


Reverifying wandb model #2 (observ_2jcubzj4)


100%|██████████| 6276/6276 [00:00<00:00, 29388.66it/s]


Reverifying wandb model #3 (observ_2s73ddgv)


100%|██████████| 6276/6276 [00:00<00:00, 28516.61it/s]


Reverifying wandb model #4 (observ_3oiczxq2)


100%|██████████| 6276/6276 [00:00<00:00, 29691.78it/s]


Reverifying wandb model #5 (observ_49tl4l12)


100%|██████████| 6276/6276 [00:00<00:00, 29821.48it/s]


Reverifying wandb model #6 (observ_4p8z3i0m)


100%|██████████| 6276/6276 [00:00<00:00, 29761.03it/s]


Reverifying wandb model #7 (observ_54cqu9ml)


100%|██████████| 6276/6276 [00:00<00:00, 29479.43it/s]


Reverifying wandb model #8 (observ_5d1b1xg1)


100%|██████████| 6276/6276 [00:00<00:00, 29624.24it/s]


Reverifying wandb model #9 (observ_5p8h9jh5)


100%|██████████| 6276/6276 [00:00<00:00, 29662.33it/s]


Reverifying wandb model #10 (observ_6plr3lcq)


100%|██████████| 6276/6276 [00:00<00:00, 29276.42it/s]


Reverifying wandb model #11 (observ_80rmxv0j)


100%|██████████| 6276/6276 [00:00<00:00, 29411.02it/s]


Reverifying wandb model #12 (observ_8ba9sw4k)


100%|██████████| 6276/6276 [00:00<00:00, 29457.56it/s]


Reverifying wandb model #13 (observ_8bhrm4hf)


100%|██████████| 6276/6276 [00:00<00:00, 29433.71it/s]


Reverifying wandb model #14 (observ_9u97mm6a)


100%|██████████| 6276/6276 [00:00<00:00, 29421.74it/s]


Reverifying wandb model #15 (observ_a45136sx)


100%|██████████| 6276/6276 [00:00<00:00, 29636.25it/s]


Reverifying wandb model #16 (observ_adi6wl92)


100%|██████████| 6276/6276 [00:00<00:00, 29522.54it/s]


Reverifying wandb model #17 (observ_augud3u4)


100%|██████████| 6276/6276 [00:00<00:00, 29413.45it/s]


Reverifying wandb model #18 (observ_bcwl2zpk)


100%|██████████| 6276/6276 [00:00<00:00, 29637.65it/s]


Reverifying wandb model #19 (observ_cdccpbp8)


100%|██████████| 6276/6276 [00:00<00:00, 29391.42it/s]


Reverifying wandb model #20 (observ_ckylh4ug)


100%|██████████| 6276/6276 [00:00<00:00, 29807.97it/s]


Reverifying wandb model #21 (observ_cwcwhkko)


100%|██████████| 6276/6276 [00:00<00:00, 28697.98it/s]


Reverifying wandb model #22 (observ_d5v1z4ej)


100%|██████████| 6276/6276 [00:00<00:00, 29584.49it/s]


Reverifying wandb model #23 (observ_du222f4z)


100%|██████████| 6276/6276 [00:00<00:00, 29658.46it/s]


Reverifying wandb model #24 (observ_dvpqu60a)


100%|██████████| 6276/6276 [00:00<00:00, 29495.91it/s]


Reverifying wandb model #25 (observ_enpygz6a)


100%|██████████| 6276/6276 [00:00<00:00, 29322.04it/s]


Reverifying wandb model #26 (observ_f9v9jrkm)


100%|██████████| 6276/6276 [00:00<00:00, 29694.39it/s]


Reverifying wandb model #27 (observ_hi3k26oe)


100%|██████████| 6276/6276 [00:00<00:00, 29463.53it/s]


Reverifying wandb model #28 (observ_hyr6xdjg)


100%|██████████| 6276/6276 [00:00<00:00, 28107.61it/s]


Reverifying wandb model #29 (observ_imb9ao61)


100%|██████████| 6276/6276 [00:00<00:00, 6368.81it/s]


Reverifying wandb model #30 (observ_jisd1e8m)


100%|██████████| 6276/6276 [00:00<00:00, 29735.21it/s]


Reverifying wandb model #31 (observ_jyjgs8ey)


100%|██████████| 6276/6276 [00:00<00:00, 29546.77it/s]


Reverifying wandb model #32 (observ_kfkaui6j)


100%|██████████| 6276/6276 [00:00<00:00, 29387.05it/s]


Reverifying wandb model #33 (observ_km814edf)


100%|██████████| 6276/6276 [00:00<00:00, 29381.97it/s]


Reverifying wandb model #34 (observ_l4gxd8lw)


100%|██████████| 6276/6276 [00:00<00:00, 27695.87it/s]


Reverifying wandb model #35 (observ_lm0yimqa)


100%|██████████| 6276/6276 [00:00<00:00, 29639.82it/s]


Reverifying wandb model #36 (observ_mhickp7v)


100%|██████████| 6276/6276 [00:00<00:00, 29755.18it/s]


Reverifying wandb model #37 (observ_n2phdmek)


100%|██████████| 6276/6276 [00:00<00:00, 29677.38it/s]


Reverifying wandb model #38 (observ_npavtj6h)


100%|██████████| 6276/6276 [00:00<00:00, 29506.10it/s]


Reverifying wandb model #39 (observ_nsu2734z)


100%|██████████| 6276/6276 [00:00<00:00, 28249.50it/s]


Reverifying wandb model #40 (observ_qf9pdl27)


100%|██████████| 6276/6276 [00:00<00:00, 29341.82it/s]


Reverifying wandb model #41 (observ_r7yaxfue)


100%|██████████| 6276/6276 [00:00<00:00, 28644.02it/s]


Reverifying wandb model #42 (observ_reblty69)


100%|██████████| 6276/6276 [00:00<00:00, 29670.22it/s]


Reverifying wandb model #43 (observ_s1dx510i)


100%|██████████| 6276/6276 [00:00<00:00, 29651.01it/s]


Reverifying wandb model #44 (observ_vfu3r9k7)


100%|██████████| 6276/6276 [00:00<00:00, 29556.89it/s]


Reverifying wandb model #45 (observ_w3yiycgl)


100%|██████████| 6276/6276 [00:00<00:00, 29596.80it/s]


Reverifying wandb model #46 (observ_x8g0t2zu)


100%|██████████| 6276/6276 [00:00<00:00, 28785.44it/s]


Reverifying wandb model #47 (observ_x9qeeoja)


100%|██████████| 6276/6276 [00:00<00:00, 29661.36it/s]


Reverifying wandb model #48 (observ_xbnsp9u1)


100%|██████████| 6276/6276 [00:00<00:00, 28617.05it/s]


Reverifying wandb model #49 (observ_zg25g5s9)


100%|██████████| 6276/6276 [00:00<00:00, 29693.22it/s]


In [4]:
keep_the_best_few_models(
    model_params={
        "struct_type": "unrelaxed",
        "model_type": "CGCNN",
        "training_fraction":1.0,
    },
    num_best_models=3
)

Copied model observ_du222f4z to best_0
Copied model observ_augud3u4 to best_1
Copied model observ_jisd1e8m to best_2
Kept the best 3 models in ./best_models/CGCNN/dft_e_hull_htvs_data_unrelaxed_CGCNN
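
The selection step can be pictured as a ranking by validation error. The sketch below is illustrative only: it assumes lower validation error is better, uses made-up model names and error values, and does not reproduce the internals of `keep_the_best_few_models`.

```python
def pick_best_models(val_errors, num_best_models):
    """Return the names of the models with the lowest validation error, best first."""
    ranked = sorted(val_errors, key=val_errors.get)
    return ranked[:num_best_models]


# Made-up validation errors, for illustration only.
val_errors = {"model_a": 0.031, "model_b": 0.027, "model_c": 0.045, "model_d": 0.029}

best = pick_best_models(val_errors, num_best_models=3)

# Map each selected model to a best_<rank> slot, mirroring the naming in the log above.
targets = {name: f"best_{rank}" for rank, name in enumerate(best)}
```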


In [5]:
get_all_model_predictions(
    model_params={
        "struct_type": "unrelaxed",
        "model_type": "CGCNN",
        "training_fraction":1.0,
    },
    gpu_num=0,
    num_best_models=3
)

Loaded data
Completed data processing


100%|██████████| 6276/6276 [00:00<00:00, 8582.70it/s]
100%|██████████| 1261/1261 [00:00<00:00, 8386.20it/s]
100%|██████████| 6276/6276 [01:28<00:00, 70.77it/s] 


Timing...
19.374361753463745
19.77317476272583


100%|██████████| 6276/6276 [00:00<00:00, 28517.56it/s]


Timing...
0.6810338497161865
1.0798468589782715


100%|██████████| 6276/6276 [00:00<00:00, 28278.88it/s]


Timing...
0.6692249774932861
1.068037986755371
Completed model prediction for test_set
Loaded data
Completed data processing


100%|██████████| 6276/6276 [00:00<00:00, 8342.05it/s]
100%|██████████| 600/600 [00:00<00:00, 8581.52it/s]
100%|██████████| 6276/6276 [01:29<00:00, 70.27it/s] 


Timing...
10.09644889831543
10.2840576171875


100%|██████████| 6276/6276 [00:00<00:00, 26744.11it/s]


Timing...
0.3335995674133301
0.5212082862854004


100%|██████████| 6276/6276 [00:00<00:00, 28275.02it/s]


Timing...
0.32080578804016113
0.5084145069122314
Completed model prediction for holdout_set_B_sites
Loaded data
Completed data processing


100%|██████████| 6276/6276 [00:00<00:00, 8295.12it/s]
100%|██████████| 863/863 [00:00<00:00, 8340.32it/s]
100%|██████████| 6276/6276 [01:28<00:00, 71.31it/s] 


Timing...
13.80959701538086
14.086063623428345


100%|██████████| 6276/6276 [00:00<00:00, 28615.84it/s]


Timing...
0.4729421138763428
0.7494087219238281


100%|██████████| 6276/6276 [00:00<00:00, 28612.60it/s]


Timing...
0.4615757465362549
0.7380423545837402
Completed model prediction for holdout_set_series
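
The notebook does not say how the predictions from the three best models are combined downstream. One common option, sketched here purely as an assumption, is to report the per-structure mean across models together with the spread as a rough uncertainty estimate. The function name `ensemble_summary` and the numbers are hypothetical.

```python
from statistics import mean, pstdev


def ensemble_summary(predictions_per_model):
    """Return (mean, population std) per structure across several models.

    predictions_per_model: one list of predictions per model, all in the
    same structure order.
    """
    per_structure = zip(*predictions_per_model)
    return [(mean(values), pstdev(values)) for values in per_structure]


# Made-up predictions from three models for two structures.
preds = [
    [0.10, 0.50],
    [0.20, 0.50],
    [0.30, 0.50],
]
summary = ensemble_summary(preds)
```

A large spread for a structure indicates that the selected models disagree on it.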


In [2]:
get_all_embeddings(
    model_params={
        "struct_type": "unrelaxed",
        "model_type": "CGCNN",
        "training_fraction":1.0,
    },
    gpu_num=0,
    num_best_models=3
)

Loaded data
Completed data processing


100%|██████████| 6276/6276 [00:00<00:00, 8348.79it/s]
100%|██████████| 1261/1261 [00:00<00:00, 8336.74it/s]
100%|██████████| 6276/6276 [01:28<00:00, 71.25it/s] 
100%|██████████| 6276/6276 [00:00<00:00, 30239.53it/s]
100%|██████████| 6276/6276 [00:00<00:00, 29867.30it/s]


Completed embedding extraction for test_set
Loaded data
Completed data processing


100%|██████████| 6276/6276 [00:00<00:00, 8359.19it/s]
100%|██████████| 600/600 [00:00<00:00, 8386.85it/s]
100%|██████████| 6276/6276 [01:29<00:00, 70.25it/s] 
100%|██████████| 6276/6276 [00:00<00:00, 29880.65it/s]
100%|██████████| 6276/6276 [00:00<00:00, 29005.38it/s]


Completed embedding extraction for holdout_set_B_sites
Loaded data
Completed data processing


100%|██████████| 6276/6276 [00:00<00:00, 8363.62it/s]
100%|██████████| 863/863 [00:00<00:00, 8363.10it/s]
100%|██████████| 6276/6276 [01:31<00:00, 68.90it/s] 
100%|██████████| 6276/6276 [00:00<00:00, 30298.95it/s]
100%|██████████| 6276/6276 [00:00<00:00, 29643.19it/s]


Completed embedding extraction for holdout_set_series


# e3nn

In [None]:
reverify_wandb_models(
    model_params={
        "struct_type": "relaxed",
        "model_type": "e3nn",
        "training_fraction":0.5,
    },
    gpu_num=0
)

In [None]:
keep_the_best_few_models(
    model_params={
        "struct_type": "relaxed",
        "model_type": "e3nn",
        "training_fraction":0.5,
    },
    num_best_models=3
)

In [None]:
get_all_model_predictions(
    model_params={
        "struct_type": "relaxed",
        "model_type": "e3nn",
        "training_fraction":0.5,
    },
    gpu_num=0,
    num_best_models=3
)

In [None]:
get_all_embeddings(
    model_params={
        "struct_type": "relaxed",
        "model_type": "e3nn",
        "training_fraction":0.5,
    },
    gpu_num=0,
    num_best_models=3
)

# Painn

In [None]:
reverify_wandb_models(
    model_params={
        "struct_type": "unrelaxed",
        "model_type": "Painn",
        "training_fraction":1.0,
    },
    gpu_num=0
)


In [None]:
keep_the_best_few_models(
    model_params={
        "struct_type": "unrelaxed",
        "model_type": "Painn",
        "training_fraction":1.0,
    },
    num_best_models=3
)

In [None]:
get_all_model_predictions(
    model_params={
        "struct_type": "unrelaxed",
        "model_type": "Painn",
        "training_fraction":1.0,
    },
    gpu_num=0,
    num_best_models=3
)
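
The three sections above repeat the same sequence of calls with different `model_params`. A minimal sketch of a loop over the configurations follows; `CONFIGS` and `run_all` are hypothetical names, and the four functions are passed in as arguments so the sketch runs without the `inference` package. Embedding extraction is limited to CGCNN and e3nn, matching the cells above.

```python
CONFIGS = [
    {"struct_type": "unrelaxed", "model_type": "CGCNN", "training_fraction": 1.0},
    {"struct_type": "relaxed", "model_type": "e3nn", "training_fraction": 0.5},
    {"struct_type": "unrelaxed", "model_type": "Painn", "training_fraction": 1.0},
]

# Model types for which this notebook extracts embeddings.
EMBEDDING_MODELS = ("CGCNN", "e3nn")


def run_all(configs, reverify, keep_best, predict, embed,
            gpu_num=0, num_best_models=3):
    """Run the four steps for each configuration, in the order used above."""
    for params in configs:
        reverify(model_params=params, gpu_num=gpu_num)
        keep_best(model_params=params, num_best_models=num_best_models)
        predict(model_params=params, gpu_num=gpu_num,
                num_best_models=num_best_models)
        if params["model_type"] in EMBEDDING_MODELS:
            embed(model_params=params, gpu_num=gpu_num,
                  num_best_models=num_best_models)
```

With the imports from the first cell, the call would be `run_all(CONFIGS, reverify_wandb_models, keep_the_best_few_models, get_all_model_predictions, get_all_embeddings)`.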