# qcGEM

### Load packages

In [1]:
%load_ext autoreload  
%autoreload 2  
%reload_ext autoreload

import os
# Make the repository's run/ directory the working directory. The relative
# paths configured later ('../model/', '../data/') are resolved against it.
# NOTE(review): this is a machine-specific absolute path; adjust it to your own
# checkout. Presumably it is also what lets the local imports below (models,
# dataset_pyg, cal_loss) resolve; confirm.
os.chdir('/export/disk6/why/workbench/MERGE/GLI/0_qcGEM_Github_copy/run/')
import os
import sys
import argparse
import random
import time
import numpy as np
import pandas as pd
import copy
import torch
import torch.utils.data
from torch import nn, optim
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_add_pool, aggr
import torch.nn.init as init  

import models
from dataset_pyg import qcGEM_Data, qcGEM_example
from cal_loss import Loss_GNE

import umap
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold

import molplotly
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import plotly.express as px
import seaborn as sns

## Load model

In [None]:
import argparse

# Only the encoder/decoder class names are declared as real command-line
# options; every other setting is attached to the namespace afterwards.
parser = argparse.ArgumentParser(
    description='Args of Representation Pre-train Model',
)
parser.add_argument(
    '--encoder_method',
    type=str,
    default='qcGEM_Encoder',
    metavar='N',
    help='Select the encoder model.',
)
parser.add_argument(
    '--decoder_method',
    type=str,
    default='qcGEM_Decoder',
    metavar='N',
    help='Select the decoder model.',
)

# An explicit argv list is passed so the notebook kernel's own command line
# is never parsed.
args = parser.parse_args(
    ['--encoder_method', 'qcGEM_Encoder', '--decoder_method', 'qcGEM_Decoder']
)

# Network size: layer counts, attention heads, per-head / bottleneck widths
# and the gm_* settings. vars(args) is the namespace's attribute dict, so
# update() is equivalent to assigning each attribute one by one.
vars(args).update(
    encoder_layers=16,
    decoder_layers=0,
    heads=8,
    global_head_dim=32,
    node_head_dim=32,
    edge_head_dim=32,
    botnec_global_dim=128,
    botnec_node_dim=128,
    botnec_edge_dim=128,
    gm_interact_time=4,
    gm_layer_num=3,
    gm_cutoff=8.0,
    gm_output_dim=12,
)

# Initialisation, normalisation and masking settings.
# Note that `init` is the string 'None', not the None object.
vars(args).update(
    init='None',
    norm='layer',
    remove_self_loop=False,
    global_mask_ratio=1.0,
    mask_ratio=0.3,
    replace_ratio=0.3,
    remask_ratio=0.0,
)

# Compute device and checkpoint location (path is relative to the working
# directory). NOTE(review): 'cuda:7' is specific to the authors' machine.
vars(args).update(
    device='cuda:7',
    pretrained=True,
    pretrained_path='../model/',
    pretrained_model='qcGEM_ckpt.pt',
)

# Data loading settings.
vars(args).update(
    batch_size=1,
    root_path='../data/',
    dataset='20250101',
    shuffle=False,
)

In [None]:
def build_dataset(args):
    """Create the training and validation loaders for the example dataset.

    Args:
        args: Namespace providing ``root_path``, ``dataset``, ``batch_size``
            and ``shuffle``.

    Returns:
        A ``(args, train_loader, valid_loader)`` tuple; ``args`` is handed
        back unchanged.
    """
    # The split is always 'random' with a fixed seed of 0.
    example_set = qcGEM_example(
        root=args.root_path,
        dataset=args.dataset,
        split_mode='random',
        split_seed=0,
    )

    # Only the training loader honours args.shuffle; validation order is fixed.
    train_loader = DataLoader(
        example_set.train,
        batch_size=args.batch_size,
        shuffle=args.shuffle,
    )
    valid_loader = DataLoader(
        example_set.val,
        batch_size=args.batch_size,
        shuffle=False,
    )

    return args, train_loader, valid_loader

In [4]:
# Build both loaders once at notebook scope; build_dataset returns args
# unchanged alongside the train and validation loaders.
args, data_loader_train, data_loader_valid = build_dataset(args)

In [5]:
def build_model(args):
    """Instantiate the qcGEM model and optionally load pretrained weights.

    Prints the total, encoder and decoder parameter counts. When
    ``args.pretrained`` is truthy, the checkpoint at
    ``args.pretrained_path`` / ``args.pretrained_model`` is loaded and its
    ``'model_state_dict'`` entry is applied with ``strict=True``.

    Args:
        args: Namespace carrying the model hyper-parameters, the target
            ``device`` and the checkpoint settings.

    Returns:
        A ``(args, model)`` tuple; ``args`` is handed back unchanged.
    """
    # Input feature sizes are fixed here rather than read from args.
    global_dim, node_dim, edge_dim = [200, 512, 512], 80, 53

    # Keyword spellings (botnec_* vs BotNec_*) follow the models.qcGEM call
    # as originally written and must not be normalised here.
    model = models.qcGEM(
        input_global_dim=global_dim,
        global_head_dim=args.global_head_dim,
        botnec_global_dim=args.botnec_global_dim,
        input_node_dim=node_dim,
        node_head_dim=args.node_head_dim,
        BotNec_node_dim=args.botnec_node_dim,
        input_edge_dim=edge_dim,
        edge_head_dim=args.edge_head_dim,
        BotNec_edge_dim=args.botnec_edge_dim,
        heads=args.heads,
        device=args.device,
        act_fn=nn.GELU(),
        norm=args.norm,
        remove_self_loop=args.remove_self_loop,
        global_mask_ratio=args.global_mask_ratio,
        mask_ratio=args.mask_ratio,
        replace_ratio=args.replace_ratio,
        remask_ratio=args.remask_ratio,
        encoder_method=args.encoder_method,
        decoder_method=args.decoder_method,
        encoder_layers=args.encoder_layers,
        decoder_layers=args.decoder_layers,
        gm_cutoff=args.gm_cutoff,
        gm_output_dim=args.gm_output_dim,
        gm_interact_time=args.gm_interact_time,
        gm_layer_num=args.gm_layer_num,
    )

    total_params = sum(p.numel() for p in model.parameters())
    encoder_params = sum(p.numel() for p in model.encoder.parameters())
    decoder_params = sum(p.numel() for p in model.decoder.parameters())
    print(f' ==== The total num of parameters is {total_params}, Encoder is {encoder_params}, Decoder is {decoder_params}.')

    if args.pretrained:
        # os.path.join avoids the doubled separator the f-string produced
        # when pretrained_path already ends with '/'.
        checkpoint_path = os.path.join(args.pretrained_path, args.pretrained_model)
        # SECURITY: torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=torch.device(args.device))
        # strict=True: missing or unexpected keys raise instead of being
        # silently ignored.
        model.load_state_dict(checkpoint['model_state_dict'], strict=True)

    return args, model

In [6]:
# Instantiate the model (loading the checkpoint when args.pretrained is set);
# build_model also prints the parameter counts shown below.
args, model = build_model(args)

 ==== The total num of parameters is 37002671, Encoder is 35958209, Decoder is 1044321.


### Inference

In [13]:
# Run a single forward pass on the first training batch and print the
# 'CID_list' entry of the first element of the model output.
# NOTE(review): the model is never switched to eval() in this notebook;
# confirm whether train-mode behaviour is intended for inference.
with torch.no_grad():  # no gradients are needed for inference
    for batch in data_loader_train:
        batch = batch.to(args.device)
        output = model(batch)
        print(output[0]['CID_list'])
        break  # only the first batch is inspected

['000781927']
