# This notebook shows example to load the CHGNet for prediction


In [None]:
import numpy as np
import torch
from pymatgen.core import Structure

from chgnet.model import CHGNet

# If the above line fails in Google Colab due to numpy version issue,
# please restart the runtime, and the problem will be solved

np.set_printoptions(precision=4, suppress=True)

### Override torch.bincount (not supported by ONNX)

In [None]:
def new_bincount(input_tensor, minlength: int = 0):
    counts = torch.tensor([], dtype=torch.int64)
    output_length = max(minlength, int(torch.max(input_tensor)) + 1)
    for value in range(output_length):
        value_count = (input_tensor == value).sum()
        new_count_tensor = torch.tensor([value_count])
        counts = torch.cat([counts, new_count_tensor])
    return counts

torch.bincount=new_bincount

### Read structure from a json or cif file


In [None]:
structure = Structure.from_file("/Users/chrisfajardo/git_repos/pochi/Li.cif")

### Load Model


In [None]:
chgnet = CHGNet.load()

# Alternatively you can read your own model
# chgnet = CHGNet.from_file(model_path)

In [None]:
graph = chgnet.graph_converter(structure)

In [None]:
chgnet = chgnet.to("cpu")

In [None]:
chgnet = chgnet.eval()

In [None]:
chgnet.forward([list(graph)])

### Create traced model and save TorchScript

In [None]:
traced = torch.jit.trace(chgnet, ([list(graph)],))

In [None]:
traced.save("chgnet.torchscript")

In [None]:
c = torch.jit.load("chgnet.torchscript")

In [None]:
c.forward([list(graph)])

### Try to export to ONNX

In [14]:
torch.onnx.export(traced, [list(graph)], f="chgnet.onnx")

SymbolicValueError: Cannot determine scalar type for this '<class 'torch.TensorType'>' instance and a default value was not provided.  [Caused by the value '182 defined in (%182 : Tensor = onnx::Gather[axis=1](%157, %149), scope: chgnet.model.model.CHGNet:: # /Users/chrisfajardo/git_repos/chgnet/chgnet/model/model.py:875:0
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::Gather'.] 
    (node defined in /Users/chrisfajardo/git_repos/chgnet/chgnet/model/model.py(875): get_batch_graph_from_graphs
/Users/chrisfajardo/git_repos/chgnet/chgnet/model/model.py(382): forward
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/torch/nn/modules/module.py(1522): _slow_forward
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/torch/nn/modules/module.py(1541): _call_impl
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/torch/nn/modules/module.py(1532): _wrapped_call_impl
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/torch/jit/_trace.py(1088): trace_module
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/torch/jit/_trace.py(820): trace
/var/folders/89/hcxcrt7s2zb8yq5x9mmscspw0000gq/T/ipykernel_4148/2315511256.py(1): <module>
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3526): run_code
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3466): run_ast_nodes
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3284): run_cell_async
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3079): _run_cell
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3024): run_cell
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/ipykernel/zmqshell.py(549): run_cell
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/ipykernel/ipkernel.py(429): do_execute
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/ipykernel/kernelbase.py(767): execute_request
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/ipykernel/kernelbase.py(429): dispatch_shell
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/ipykernel/kernelbase.py(523): process_one
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/ipykernel/kernelbase.py(534): dispatch_queue
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/asyncio/events.py(80): _run
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/asyncio/base_events.py(1905): _run_once
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/asyncio/base_events.py(601): run_forever
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/tornado/platform/asyncio.py(195): start
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/ipykernel/kernelapp.py(701): start
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/traitlets/config/application.py(992): launch_instance
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/site-packages/ipykernel_launcher.py(17): <module>
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/runpy.py(87): _run_code
/Users/chrisfajardo/miniconda3/envs/chgnet/lib/python3.9/runpy.py(197): _run_module_as_main
)

    Inputs:
        #0: 157 defined in (%155 : Tensor, %156 : Tensor, %157 : Tensor, %158 : Tensor, %image : Tensor, %160 : Tensor, %undirected2directed : Tensor, %162 : Tensor, %163 : Tensor, %lattice : Tensor = prim::ListUnpack(%154), scope: chgnet.model.model.CHGNet::
    )  (type 'Tensor')
        #1: 149 defined in (%149 : Long(device=cpu) = onnx::Constant[value={0}](), scope: chgnet.model.model.CHGNet::
    )  (type 'Tensor')
    Outputs:
        #0: 182 defined in (%182 : Tensor = onnx::Gather[axis=1](%157, %149), scope: chgnet.model.model.CHGNet:: # /Users/chrisfajardo/git_repos/chgnet/chgnet/model/model.py:875:0
    )  (type 'Tensor')

In [None]:
#scripted = torch.jit.script(chgnet, example_inputs=[([graph.to_dict()],)])