In [11]:
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
%autoreload 2
from typing import Optional
import torch
print(torch.__version__)
import torch.export
import torch.nn as nn
import tqdm
import numpy as np
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import seaborn as sns
import lovelyplots
plt.style.use('ipynb')

import e3nn
# Change e3nn's global optimization defaults so jit_script_fx is off.
# NOTE(review): presumably needed so torch.compile / torch.export can trace the
# model without e3nn TorchScript-compiling its generated code — confirm.
e3nn.set_optimization_defaults(jit_script_fx=False)
# Process-wide PyTorch setting controlling the internal precision of float32 matmuls.
torch.set_float32_matmul_precision('high')

import e3tools
import rdkit
import rdkit.Chem as Chem

from bond_predictor import models, rdkit_utils

2.6.0+cu124


In [13]:
# Instantiate the E3Conv model with its default arguments and immediately wrap it
# with torch.compile (dynamic shapes enabled, graph breaks disallowed).
net = torch.compile(
    models.E3Conv(),
    dynamic=True,
    fullgraph=True,
)

In [14]:
# Training step of the checkpoint to restore.
checkpoint_step = 80000
# map_location='cpu' lets the checkpoint load on machines without the device it
# was saved from; load_state_dict copies values into net's existing parameters,
# so net's own device placement is unaffected.
# weights_only=True restricts unpickling to tensors/plain containers (state dict only).
weights = torch.load(
    f'weights_step_{checkpoint_step}.pt',
    map_location='cpu',
    weights_only=True,
)
# Last expression of the cell: displays the key-matching report.
net.load_state_dict(weights)

<All keys matched successfully>

In [15]:
class ModelWrapper(torch.nn.Module):
    """Thin shell around a network, used as the module handed to export.

    Pins the call signature to three positional inputs and keeps the lists of
    input/output names alongside the wrapped network.

    Args:
        net: Module invoked as ``net(coordinates, atom_types, edge_index)``.
        input_keys: Stored as-is on ``self.input_keys``; not read by ``forward``.
        output_keys: Stored as-is on ``self.output_keys``; not read by ``forward``.
    """

    def __init__(self, net, input_keys, output_keys):
        super().__init__()
        self.input_keys, self.output_keys = input_keys, output_keys
        self.net = net

    def forward(self, coordinates, atom_types, edge_index):
        """Pass the three inputs, unmodified and positionally, to the wrapped network.

        Returns:
            Whatever ``self.net`` returns for these inputs.
        """
        outputs = self.net(coordinates, atom_types, edge_index)
        return outputs

In [16]:
# Names matching the three arguments of ModelWrapper.forward.
input_keys = ['coordinates', 'atom_types', 'edge_index']
# Names for the model outputs.
# NOTE(review): order presumably matches the tuple the network returns — confirm.
output_keys = ['bond_logits', 'charge_logits']

wrapped_net = ModelWrapper(net=net, input_keys=input_keys, output_keys=output_keys)

In [17]:
# Example inputs (coordinates, atom_types, edge_index) used to trace and smoke-test
# the model. A private, seeded generator makes them identical on every run without
# touching the global RNG state.
_dummy_rng = torch.Generator().manual_seed(0)
_num_dummy_nodes = 10    # rows of coordinates / entries of atom_types
_num_dummy_edges = 15    # columns of edge_index
_atom_type_bound = 100   # exclusive upper bound for the random atom-type ids
dummy_data = (
    # coordinates: float tensor of shape (nodes, 3)
    torch.randn(_num_dummy_nodes, 3, generator=_dummy_rng),
    # atom_types: integer ids in [0, _atom_type_bound)
    torch.randint(0, _atom_type_bound, (_num_dummy_nodes,), generator=_dummy_rng),
    # edge_index: shape (2, edges), entries are node indices in [0, nodes)
    torch.randint(0, _num_dummy_nodes, (2, _num_dummy_edges), generator=_dummy_rng),
)

In [18]:
# Symbolic sizes for the two axes declared dynamic below, so the exported graph
# is not tied to the sizes of dummy_data.
num_nodes, num_edges = (
    torch.export.Dim(label, min=0, max=1_000_000)
    for label in ("num_nodes", "num_edges")
)

# Keys are the forward() argument names; each value maps a tensor axis to its symbol.
# coordinates and atom_types share num_nodes on axis 0; edge_index uses num_edges on axis 1.
dynamic_shapes = {
    "coordinates": {0: num_nodes},
    "atom_types": {0: num_nodes},
    "edge_index": {1: num_edges},
}

exported = torch.export.export(wrapped_net, dummy_data, dynamic_shapes=dynamic_shapes)

In [19]:
# Ahead-of-time compile the exported program with Inductor and bundle the result
# into a package file; the returned path is what gets loaded afterwards.
output_path = 'bond_predictor.pt2'
out_path = torch._inductor.aoti_compile_and_package(exported, package_path=output_path)

/tmp/torchinductor_ameyad/c7qdxnsfszo7llkcgigaof6rgaspo7tww4tfz73hwnlfmj3njxgd/cp7yaftbo6hb42ptir7ndfhsf5s62zb2aiky6y53vulsfhtva5f7.cpp: In function ‘void cpp_fused_bmm_cat_scatter_reduce_zeros_2(const float*, const int64_t*, const int64_t*, const float*, const float*, float*, float*, float*, float*, float*, int64_t, int64_t)’:
  844 |                     float tmp_acc0_arr[8];
      |                           ^~~~~~~~~~~~
  849 |                     float tmp_acc1_arr[8];
      |                           ^~~~~~~~~~~~
/tmp/torchinductor_ameyad/c7qdxnsfszo7llkcgigaof6rgaspo7tww4tfz73hwnlfmj3njxgd/cp7yaftbo6hb42ptir7ndfhsf5s62zb2aiky6y53vulsfhtva5f7.cpp: In function ‘void cpp_fused_add_bmm_mul_5(const float*, const int64_t*, const float*, const float*, float*, float*, float*, float*, float*, float*, int64_t, int64_t)’:
 1166 |                     float tmp_acc0_arr[8];
      |                           ^~~~~~~~~~~~
 1171 |                     float tmp_acc1_arr[8];
      |             

In [20]:
# Load the packaged AOT-compiled model back from disk.
aot_model = torch._inductor.aoti_load_package(out_path)
# Unpack the example tuple so the call mirrors the exported signature
# forward(coordinates, atom_types, edge_index); passing the tuple as a single
# positional argument only works if the loader happens to flatten its inputs.
aot_out = aot_model(*dummy_data)
# The two outputs, named as in output_keys.
bond_logits, charge_logits = aot_out