In [34]:
import yaml
from matdeeplearn.preprocessor.processor import process_data

# Parse the YAML experiment configuration; `cfg` is reused by later cells
# (dataset processing here, model loading further down).
with open("../configs/config.yml", mode="r") as config_file:
    cfg = yaml.safe_load(config_file)

# Hand the "dataset" section of the config to process_data. The result is
# indexed by split name later in the notebook (dataset["full"]).
dataset_cfg = cfg["dataset"]
dataset = process_data(dataset_cfg)

In [35]:
import torch
from matdeeplearn.models.cgcnn import CGCNN
from matdeeplearn.common.ase_utils import MDLCalculator

# _load_model returns an indexable collection; only its first entry is used.
# NOTE(review): presumably one entry per configured model — confirm against
# MDLCalculator._load_model.
loaded_models = MDLCalculator._load_model(cfg, rank=0)
model = loaded_models[0]
# Bare expression so the notebook displays the module structure.
model

QCGCNN(
  (distance_expansion): GaussianSmearing()
  (pre_lin_list): ModuleList(
    (0): Linear(in_features=100, out_features=100, bias=True)
  )
  (conv_list): ModuleList(
    (0-3): 4 x QCGConv(100, dim=50)
  )
  (bn_list): ModuleList()
  (post_lin_list): ModuleList(
    (0): Linear(in_features=100, out_features=150, bias=True)
    (1-2): 2 x Linear(in_features=150, out_features=150, bias=True)
  )
  (lin_out): Linear(in_features=150, out_features=1, bias=True)
)

In [36]:
# Checkpoint written by an earlier training run.
# NOTE(review): this is an absolute, machine-specific path; the cell only works
# where that filesystem is mounted. Consider deriving it from the config.
model_pth = "/net/csefiles/coc-fung-cluster/Qianyu/quant/MatDeepLearn_dev/results/2024-09-02-18-05-27-685-my_train_job/checkpoint_0/best_checkpoint.pt"

# map_location="cpu" makes deserialization independent of the device the
# checkpoint was saved from (a CUDA-saved file otherwise fails to load on a
# CPU-only machine). load_state_dict copies values into the model's existing
# parameters, so the model's own device placement is not changed by this.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
state_dict = torch.load(model_pth, map_location="cpu")["state_dict"]

# Last expression: the notebook displays the key-matching report.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [37]:
from torch_geometric.loader import DataLoader

# Unshuffled loader over the "full" split, 32 samples per batch, so every pass
# (timing, calibration, inference) sees the batches in the same order.
loader = DataLoader(
    dataset["full"],
    shuffle=False,
    batch_size=32,
)

# Grab the first batch for inspection.
batch_iter = iter(loader)
batch = next(batch_iter)

In [38]:
import torch
import torch.nn as nn
import torch.quantization as quant

# Min/max range observers: unsigned 8-bit for activations, signed 8-bit for
# weights. Both are bundled into one QConfig that is attached to the model's
# sub-modules before quant.prepare is called.
activation_observer = quant.MinMaxObserver.with_args(dtype=torch.quint8)
weight_observer = quant.MinMaxObserver.with_args(dtype=torch.qint8)

per_tensor_qconfig = quant.QConfig(
    activation=activation_observer,
    weight=weight_observer,
)

In [39]:
# Keep the model on the CPU for the timing and quantization cells that follow.
target_device = "cpu"
model = model.to(target_device)

In [40]:
%%timeit
# Baseline timing: one full pass of the not-yet-quantized model over the
# loader, in eval mode and with autograd disabled.
model.eval()
with torch.no_grad():
    for batch in loader:
        out = model(batch)

386 ms ± 3.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [42]:
# Attach the same qconfig to each sub-module that should be considered for
# quantization (same order as before: pre-linear, post-linear, conv, output).
for quant_target in (
    model.pre_lin_list,
    model.post_lin_list,
    model.conv_list,
    model.lin_out,
):
    quant_target.qconfig = per_tensor_qconfig

# Prepare the model in place for calibration (modifies `model` itself).
quant.prepare(model, inplace=True)

# Calibration: push sample data through the prepared model so that whatever
# hooks/observers are attached to it get to see representative inputs.
@torch.no_grad()
def calibrate(model, data_loader):
    """Run every batch of ``data_loader`` through ``model`` in eval mode.

    Gradients are disabled for the whole call and the model outputs are
    discarded; the function exists purely for the side effects of the forward
    passes. The model is left in eval mode afterwards.

    Args:
        model: Callable object exposing ``eval()``; invoked once per batch.
        data_loader: Iterable yielding the batches to feed to ``model``.

    Returns:
        None.
    """
    model.eval()
    for calibration_batch in data_loader:
        model(calibration_batch)

# Run calibration on the CPU, over the same loader used for the timing cell.
# This must happen after quant.prepare (earlier in this cell) and before
# quant.convert below; the three steps mutate `model` in place, so this cell
# is not safe to re-run without reloading the model first.
model = model.to("cpu")
calibrate(model, loader)

# Convert the calibrated model to its quantized version in place. As the last
# expression of the cell, the value returned by convert is what the notebook
# displays.
quant.convert(model, inplace=True)



QCGCNN(
  (distance_expansion): GaussianSmearing()
  (pre_lin_list): ModuleList(
    (0): QuantizedLinear(in_features=100, out_features=100, scale=0.001273611793294549, zero_point=160, qscheme=torch.per_tensor_affine)
  )
  (conv_list): ModuleList(
    (0-3): 4 x QCGConv(100, dim=50)
  )
  (bn_list): ModuleList()
  (post_lin_list): ModuleList(
    (0): QuantizedLinear(in_features=100, out_features=150, scale=0.05809524282813072, zero_point=115, qscheme=torch.per_tensor_affine)
    (1): QuantizedLinear(in_features=150, out_features=150, scale=0.03292454779148102, zero_point=88, qscheme=torch.per_tensor_affine)
    (2): QuantizedLinear(in_features=150, out_features=150, scale=0.029851863160729408, zero_point=95, qscheme=torch.per_tensor_affine)
  )
  (lin_out): QuantizedLinear(in_features=150, out_features=1, scale=0.016913220286369324, zero_point=255, qscheme=torch.per_tensor_affine)
)

In [43]:
# Fetch the first batch again and display it to inspect its fields and shapes.
first_batch_iter = iter(loader)
batch = next(first_batch_iter)
batch

DataBatch(n_atoms=[32], pos=[320, 3], cell=[32, 3, 3], structure_id=[32], z=[320], u=[32, 3], y=[32, 1], edge_index=[2, 2704], edge_weight=[2704], edge_vec=[2704, 3], cell_offsets=[2704, 3], neighbors=[0], x=[320, 100], edge_attr=[2704, 50], batch=[320], ptr=[33])

In [44]:
import torch

def quantize_to_quint8(tensor):
    """Quantize a float tensor to ``torch.quint8`` with per-tensor affine parameters.

    The scale and zero point are derived from the tensor's own min/max, with
    the range widened so that it always contains 0.0. This matters for two
    reasons:

    * the zero point must lie in [0, 255]; if the range started above zero
      the zero point would be clamped to 0 while the scale still only covered
      ``max - min``, and every value would saturate at 255;
    * 0.0 is then exactly representable.

    Degenerate inputs are handled explicitly: an empty tensor, or a tensor
    whose widened range has zero width (all zeros), is quantized with
    ``scale=1.0`` and ``zero_point=0`` instead of dividing by zero.

    Args:
        tensor: Floating-point tensor accepted by ``torch.quantize_per_tensor``.

    Returns:
        A quantized tensor of dtype ``torch.quint8`` with the same shape.
    """
    if tensor.numel() == 0:
        # min()/max() are undefined on empty tensors; any valid parameters do.
        return torch.quantize_per_tensor(tensor, 1.0, 0, torch.quint8)

    # Widen the observed range to include zero (see docstring).
    min_val = min(tensor.min().item(), 0.0)
    max_val = max(tensor.max().item(), 0.0)

    # Calculate scale and zero_point
    scale = (max_val - min_val) / 255.0
    if not scale > 0.0:
        # Zero-width range (all-zero tensor): avoid a zero scale / 0-by-0 division.
        scale = 1.0
    zero_point = int(round(-min_val / scale))
    zero_point = max(0, min(255, zero_point))

    # Quantize
    return torch.quantize_per_tensor(tensor, scale, zero_point, torch.quint8)

def quantize_batch(batch):
    """Quantize selected fields of ``batch`` in place and return the same object.

    Walks ``batch.items()`` and, for the fields named ``x``, ``pos``,
    ``cell``, ``edge_weight`` and ``edge_attr``, replaces the stored value
    with its ``quantize_to_quint8`` result. All other fields are untouched.

    Args:
        batch: Object exposing ``items()`` (key/value pairs) and supporting
            attribute assignment for those keys.

    Returns:
        The same ``batch`` object, mutated.
    """
    fields_to_quantize = ('x', 'pos', "cell", 'edge_weight', 'edge_attr')
    for field_name, field_value in batch.items():
        if field_name not in fields_to_quantize:
            continue
        setattr(batch, field_name, quantize_to_quint8(field_value))
    return batch

In [45]:
%%timeit
# Timing of the converted model: one full pass over the loader in eval mode
# with autograd disabled. Each batch is quantized on the fly, so the measured
# time includes quantize_batch as well as the forward pass.
model.eval()
with torch.no_grad():
    for batch in loader:
        model(quantize_batch(batch), inference=True)

340 ms ± 9.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [46]:
# Single forward pass of the converted model on the first (quantized) batch;
# the returned value is displayed as the cell output.
first_batch = next(iter(loader))
first_batch_quantized = quantize_batch(first_batch)
model(first_batch_quantized, inference=True)

{'output': tensor([[0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.],
         [0.]]),
 'pos_grad': None,
 'cell_grad': None}