In [4]:
import random
import torch
import torch.nn as nn
from crystal_graph import CIFData
from crystal_graph import get_train_val_test_loader
from crystal_graph import collate_pool
from Res_GCN import Res_GCN
from test import test_all_dataset

In [6]:
# load materials CIF data for constructing crystal graph
cif_data = CIFData(r"C:\Users\PC\Downloads\Res-GCN\exp_data")
# git properties of crystal graph
structure, target, id = cif_data[0]
orig_atom_fea_len = structure[0].shape[-1]
nbr_fea_len = structure[1].shape[-1]

In [7]:
# create data loader for model prediction
data_loader, _ = get_train_val_test_loader(cif_data, collate_pool,
                                                     batch_size=8,
                                                     train_ratio=1,
                                                     val_ratio=0,
                                                     test_ratio=0,
                                                     train_size=None,
                                                     test_size=None,
                                                     val_size=None,
                                                     pin_memory=False,
                                                     num_workers=2)

In [8]:
# define the model hyperparameters
num_conv_layes = 2
num_res_layers = 2
num_hidden_layers = 2
# create the empty model with the defined hyperparameters
# the model is created with the same hyperparameters as the pre-trained Res-GCN model
model = Res_GCN(orig_atom_fea_len=orig_atom_fea_len, 
                 nbr_fea_len=nbr_fea_len, 
                 atom_fea_len=64, 
                 n_conv=num_conv_layes,
                 n_resconv=num_res_layers,
                 h_fea_len=128, 
                 n_h=num_hidden_layers,
                 classification=False)


In [None]:
# load pre-trained model parameters and load them into the model
para = torch.load(r"C:\Users\PC\Downloads\Res-GCN\pre-trained\pre-trained-model.pth")
model.load_state_dict(para)
model.eval()

Res_GCN(
  (embedding): Linear(in_features=92, out_features=64, bias=True)
  (convs): ModuleList(
    (0-1): 2 x GCB(
      (fc_full): Linear(in_features=169, out_features=128, bias=True)
      (sigmoid): Sigmoid()
      (softplus1): Softplus(beta=1, threshold=20)
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (softplus2): Softplus(beta=1, threshold=20)
    )
  )
  (res_convs): ModuleList(
    (0-1): 2 x GRCB(
      (conv): GCB(
        (fc_full): Linear(in_features=169, out_features=128, bias=True)
        (sigmoid): Sigmoid()
        (softplus1): Softplus(beta=1, threshold=20)
        (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (softplus2): Softplus(beta=1, threshold=20)
      )
      (bn): Batch

In [None]:
# define the device to be used for model prediction
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# predict the permittivity of the materials in the dataset
test_all_dataset(data_loader=data_loader, model=model, device=device)

([tensor([6.9358], device='cuda:0', grad_fn=<UnbindBackward0>),
  tensor([5.9969], device='cuda:0', grad_fn=<UnbindBackward0>),
  tensor([9.5293], device='cuda:0', grad_fn=<UnbindBackward0>),
  tensor([8.7968], device='cuda:0', grad_fn=<UnbindBackward0>),
  tensor([11.7157], device='cuda:0', grad_fn=<UnbindBackward0>),
  tensor([8.7007], device='cuda:0', grad_fn=<UnbindBackward0>),
  tensor([7.8104], device='cuda:0', grad_fn=<UnbindBackward0>),
  tensor([7.5217], device='cuda:0', grad_fn=<UnbindBackward0>)],
 [tensor([3.]),
  tensor([7.]),
  tensor([1.]),
  tensor([6.]),
  tensor([4.]),
  tensor([2.]),
  tensor([8.]),
  tensor([5.])],
 ['Ba3P4O13',
  'Al2Mo3O12',
  'LiCr(MoO4)2',
  'NaLa(MoO4)2',
  'Ba3V2O8',
  'LiAl(MoO4)2',
  'Sc2Mo3O12',
  'NaNd(MoO4)2'])