In [2]:
#define ecloud latent encoder
import argparse
import torch
from coati.models.io.coati import load_e3gnn_smiles_clip_e2e
from coati.models.regression.basic_due import basic_due
from coati.utils.chem import read_sdf, write_sdf, rm_radical, sa, qed, logp
from rdkit import Chem
import random
from coati.generative.molopt import gradient_opt
from coati.generative.coati_purifications import embed_smiles
from functools import partial
from torch.nn.functional import sigmoid
import torch.nn.functional as F
import numpy as np
from coati.generative.coati_purifications import force_decode_valid_batch, embed_smiles, force_decode_valid
import os.path as osp
from coati.optimize.scoring import ScoringFunction
from coati.optimize.mol_functions import qed_score, substructure_match_score, penalize_macrocycles, heavy_atom_count, penalized_logp_score
from coati.optimize.pso_optimizer import BasePSOptimizer
from coati.optimize.swarm import Swarm
from coati.optimize.rules.qsar_score import qsar_model


# Notebook configuration, expressed as an argparse namespace so the same
# options can be reused from a script. An empty argv is parsed on purpose:
# the kernel's own command-line arguments must not leak into the config.

# Prefer GPU 3 when the machine has it, otherwise the first GPU, and finally
# the CPU, so the notebook also runs on hosts with fewer (or no) GPUs.
if torch.cuda.is_available():
    _default_device = 'cuda:3' if torch.cuda.device_count() > 3 else 'cuda:0'
else:
    _default_device = 'cpu'

arg_parser = argparse.ArgumentParser(description='molecular optimization on the chemical space')
# Any valid torch device string is accepted; a fixed `choices` list would
# have to enumerate every GPU index and would reject the default above.
arg_parser.add_argument('--device', type=str, default=_default_device, help='Device')
arg_parser.add_argument('--seed', type=int, default=2024)
arg_parser.add_argument('--ecloudgen_ckpt', type=str, default = 'model_ckpts/ecloud_smiles_67.pkl')
arg_parser.add_argument('--noise', type=float, default=0.3)
args = arg_parser.parse_args([])

# Apply the configured seed to every RNG the notebook relies on so that
# shuffling and weight initialisation are repeatable across kernel restarts.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# model loading
# DEVICE is kept as a plain device string; it is the single source of truth
# for device placement in the rest of the notebook.
DEVICE = args.device
encoder, tokenizer = load_e3gnn_smiles_clip_e2e(
    freeze=True,
    device=DEVICE,
    # model parameters to load.
    doc_url=args.ecloudgen_ckpt,
)

#ecloud latent encoder
class PSO_format_model():
    """Thin adapter exposing an encoder/tokenizer pair as SMILES <-> embedding calls.

    Args:
        model: encoder passed through to ``embed_smiles`` and
            ``force_decode_valid_batch``.
        tokenizer: tokenizer passed through to the same helpers.
        device: device every produced embedding is moved to.
    """

    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def seq_to_emb(self, smiles):
        """Embed one SMILES string or an iterable of SMILES strings.

        Args:
            smiles: a single ``str`` or an iterable of ``str``.

        Returns:
            For a single string, whatever ``embed_smiles`` returns, moved to
            ``self.device``. For an iterable, the stacked embeddings reshaped
            to ``(-1, 256)``.
        """
        if isinstance(smiles, str):
            # Use the device given to the constructor rather than a notebook
            # global, so the wrapper works wherever it is instantiated.
            return embed_smiles(smiles, self.model, self.tokenizer).to(self.device)

        emb_list = [
            embed_smiles(smi, self.model, self.tokenizer).to(self.device)
            for smi in smiles
        ]
        # 256 is the embedding width this notebook works with (the regression
        # head below is built with hidden_dim=256).
        return torch.stack(emb_list).reshape(-1, 256)

    def emb_to_seq(self, embs):
        """Decode each embedding in ``embs`` with ``force_decode_valid_batch``.

        Returns:
            list with one decoded result per input embedding, in input order.
        """
        return [
            force_decode_valid_batch(emb, self.model, self.tokenizer)
            for emb in embs
        ]

Loading model from model_ckpts/ecloud_smiles_67.pkl
Loading tokenizer mar from model_ckpts/ecloud_smiles_67.pkl
number of parameters: 12.64M
number of parameters Total: 2.44M xformer: 19.60M Total: 22.04M 
Freezing encoder
44882816 params frozen!


In [3]:
#encoder example
# Wrap the encoder/tokenizer loaded above (loaded with freeze=True) so SMILES
# can be embedded through a single object.
ecloud_latent = PSO_format_model(encoder, tokenizer, DEVICE)
init_smiles = "c1ccccc1"
# A list input takes the batch branch of seq_to_emb, which stacks the
# per-molecule embeddings and reshapes them to (-1, 256).
init_emb = ecloud_latent.seq_to_emb([init_smiles, init_smiles])
print(init_emb)
print(init_emb.size())

tensor([[-2.0585e-01,  1.0529e-01, -3.5812e-01,  2.7124e-02,  2.7386e-01,
          8.6519e-02,  1.1036e-01, -1.0451e-02, -1.8966e-01, -2.5554e-01,
         -3.3033e-01,  5.9155e-02, -1.6967e-02,  5.2427e-02, -9.2881e-02,
          2.6020e-02,  1.4271e-01,  2.7295e-02, -8.7557e-02,  4.2929e-02,
         -2.3445e-03, -1.7832e-01,  2.4162e-01, -3.1961e-01, -1.9733e-01,
         -8.7914e-02, -1.8385e-01,  7.3459e-02, -5.1598e-01, -1.0449e-01,
          3.2603e-01,  1.1386e+00,  1.6684e-01, -1.1404e-01, -1.5266e-01,
         -2.2697e-01,  1.8601e-01, -2.7853e-02, -2.5724e-01,  1.4475e-01,
         -1.4781e-02, -7.8996e-02, -7.8704e-03,  9.3806e-02, -2.3397e-01,
          6.3077e-02, -1.4422e-01,  4.2236e-01,  9.7027e-03, -1.9467e-01,
          1.8115e-01,  2.4048e-01,  8.8585e-02, -1.6303e-01, -1.6620e-01,
          2.3039e-01, -5.4375e-02,  6.1500e-02, -1.3833e-01, -7.0604e-02,
          1.1823e-01,  1.4608e-01, -2.3003e-01,  1.3378e-01,  2.7105e-02,
          5.0183e-02,  6.9273e-02,  1.

In [40]:
#data loading
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch

# Custom Dataset backed by a CSV file of SMILES strings
class PaddleDataset(Dataset):
    """Map-style dataset reading a ``smiles`` column (and optional labels) from CSV.

    Args:
        path: path of the CSV file; it must contain a ``smiles`` column.
        label_name: name of the label column, or ``None`` for unlabeled data.
    """

    def __init__(self, path, label_name=None):
        frame = pd.read_csv(path)
        self.df = frame
        # Keep the columns as numpy arrays for cheap positional indexing.
        self.smiles = frame['smiles'].values
        self.labels = frame[label_name].values if label_name is not None else None

    def __len__(self):
        """Number of rows (molecules) in the CSV."""
        return len(self.smiles)

    def __getitem__(self, idx):
        """Return ``{'smiles': str}`` plus a float ``'label'`` tensor when labels exist."""
        if self.labels is None:
            return {'smiles': self.smiles[idx]}
        return {
            'smiles': self.smiles[idx],
            'label': torch.tensor(self.labels[idx], dtype=torch.float),
        }

# Helper functions that build the DataLoaders
from sklearn.model_selection import train_test_split
def split_dataloader(csv_path, label_name=None, batch_size = 64, eval_batch_size = 64, shuffle=True, num_workers=4):
    """Load a CSV dataset and return (train_loader, valid_loader) for an 80/20 split.

    Args:
        csv_path: CSV file handed to ``PaddleDataset``.
        label_name: label column name, or ``None`` for unlabeled data.
        batch_size: batch size of the training loader.
        eval_batch_size: batch size of the validation loader.
        shuffle: shuffle flag applied to BOTH loaders.
        num_workers: worker processes per loader.

    Returns:
        tuple ``(train_loader, valid_loader)``.
    """
    samples = PaddleDataset(csv_path, label_name)
    # Fixed random_state keeps the split identical across runs.
    train_part, valid_part = train_test_split(samples, test_size=0.2, random_state=42)
    shared = dict(shuffle=shuffle, num_workers=num_workers)
    return (
        DataLoader(train_part, batch_size=batch_size, **shared),
        DataLoader(valid_part, batch_size=eval_batch_size, **shared),
    )

def get_tst_dataloader(csv_path, label_name=None, batch_size = 64, eval_batch_size = 64, shuffle=True, num_workers=4):
    """Build a single DataLoader over the whole CSV file (no split).

    Args:
        csv_path: CSV file handed to ``PaddleDataset``.
        label_name: label column name, or ``None`` for unlabeled data.
        batch_size: accepted for signature parity with ``split_dataloader``;
            not used here.
        eval_batch_size: batch size of the returned loader.
        shuffle: shuffle flag of the returned loader. It defaults to ``True``,
            so callers that write predictions back in row order must pass
            ``shuffle=False`` explicitly.
        num_workers: worker processes for the loader.

    Returns:
        the DataLoader.
    """
    return DataLoader(
        PaddleDataset(csv_path, label_name),
        batch_size=eval_batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
# Usage example: build the train/validation loaders and the test loader.
trian_csv_path = 'data/train.csv'
trn_loader, valid_loader = split_dataloader(trian_csv_path, label_name='label')
test_csv_path = 'data/test_nopoint.csv'
# shuffle=False is essential here: test predictions are later assigned to the
# rows of the test CSV by position, so the loader must preserve file order.
tst_loader = get_tst_dataloader(test_csv_path, label_name=None, shuffle=False)


# Peek at a single training batch to sanity-check the loader output.
for batch in trn_loader:
    print(batch['smiles'])
    print(batch['label'])
    break  # only the first batch is inspected


['C[C@@H]1[C@]2(C)[CH+]C[C@]13CC[C@](C)(CCC2)C3', 'CC(C)[C@@H]1C[C@H]1[C+]1[C@H]2C[C@@]2(C)CC[C@@H]1C', 'C=C1CC[C@H]([C@H](C)[C@H]2C[CH+]C[C@H](C)C2)C1', 'C=C1CC[C@H]2C[C@@H]1CCC[C+](C)C[C@@H]2C', 'C=C(C[C@]1(C)C[C@@H]1C)[C+]1CCC(C)CC1', 'C[CH+]CC1=C[C@H](C)[C@@]12CC[C@@H](C(C)C)C2', 'C[C@@H]1CC[C@@H]1[C+]1CC[C@H](C)[C@H]2CC[C@@]12C', 'C=C(C)[C@H]1CC[C@@]2(C1)[C+](C)CCC[C@@H]2C', 'C[C@H]1C[C@H](C)C12C[C@@H]1CC[C+]1CC2(C)C', 'CC(C)[C@]12C[C@@]3(C[CH+][C@@H]1C)[C@@H](C)CC[C@H]32', 'C[C+]1C[C@H](C2=C[C@H](C)CCC2)CC[C@@H]1C', 'C[C+]1C[C@]23[C@@H](C)C[C@H](C[C@H]2C1(C)C)[C@H]3C', 'C=C1[C+](CC(C)C)C[C@@H](C)[C@@H]2C[C@@H]2[C@H]1C', 'CC(C)=C1CCC[C+](C)[C@@H]2CC[C@H](C)[C@@H]12', 'C[C@@H]1[CH+][C@@H]2CC/C2=C\\CC(C)(C)CCC1', 'CC(C)=C[CH+]C/C1=C/CC[C@@H](C)CCC1', 'C[C@@H]1[C@@H]2C[C@@H]([C@@H]3[CH+]C(C)(C)CCC3)[C@H]1C2', 'CC1=C[C@H]2[C+](C(C)C)CC[C@@]2(C)CCC1', 'C=C[C+]1[C@H](C)CC[C@]1(C)CCC=C(C)C', 'CC(C)[C@@H]1CC[C@H](C)[C@@]12[CH+][C@]1(C[C@@H]1C)C2', 'C[C+]1CCCC(C)(C)CC[C@@]12C[C@H]1C[C@H]12

In [27]:
#encoder example with dataloader
# Re-create the SMILES encoder wrapper and embed a two-molecule toy batch.
ecloud_latent = PSO_format_model(encoder, tokenizer, DEVICE)
init_smiles = "c1ccccc1"
init_emb = ecloud_latent.seq_to_emb([init_smiles, init_smiles])
# Disabled experiment: embed the first training batch.
# NOTE(review): the disabled code assigns `init_smils` but embeds `init_smiles`,
# so it would embed the toy string above rather than the batch — fix before re-enabling.
# print(init_emb)
# print(init_emb.size())
# for batch in trn_loader:
#     print(batch['smiles'])
#     print(batch['label'])
#     ecloud_latent = PSO_format_model(encoder, tokenizer, DEVICE)
#     init_smils = batch['smiles']
#     init_emb = ecloud_latent.seq_to_emb(init_smiles)
#     print(init_emb)
#     print(init_emb.size())
#     break  # only the first batch is inspected

# Show the structure of one test batch (a dict produced by the default collate).
for data in tst_loader:
    print(data)
    break

{'smiles': ['C[C@@H]1CC[C@H](C[C+]2CC[C@H](C)[C@@H]3C[C@H]23)C1', 'CC1=C[C@H]2C[C@@H]1C[C@H](C(C)C)[CH+]C[C@H]2C', 'C[C+]1CC[C@H]2[C@@H](C)CCCC[C@@]23CC[C@@H]13', 'C=C(C)[C@@H]1C=CC[C@@H]([C+](C)CC)C1(C)C', 'CC(C)=C[CH+][C@@H]1C/C=C/CC[C@@H](C)[C@@H]1C', 'C[C+](C)/C1=C/[C@@H]2[C@@H](C)C[C@H]2C[C@@H](C)CC1', 'C/C1=C/C[C+]2C[C@@H](C)[C@@H](CC1)CC[C@@H]2C', 'C=C1C[C@H](C)[C@@H]2C[C@@]2(C)[C@@H](C)CC[C+]1C', 'C/C=C\\CCCC1=C[C+](C)CCC[C@H]1C', 'C[C@H]1CC[C+]2CCC[C@]3(C)CC[C@@]2(C1)C3', 'C=C1C[CH+]C[C@@]12CCC[C@@H](C)CC[C@H]2C', 'C[C@H]1C[CH+][C@@H]2[C@H](CC1)[C@]21CCCC1(C)C', 'C[C+]1CC=C[C@](C)(C[C@H]2C[C@H](C)C2)CC1', 'C=C1CCC2(CC1)CC[C+](C)CC2(C)C', 'CC1(C)C[CH+][C@@](C)(CC=C2CC2)CCC1', 'C[C@H]1[CH+][C@H]2[C@@H](CC1)[C@H](C)[C@@H]1[C@@H](C)[C@H](C)[C@H]21', 'C[C@H]1CC[C@@H]2C[C@H]3[C@H](C)[CH+]C[C@@H]1[C@@H]3[C@H]2C', 'C[C@H]1CCC[C@@H]2C[C@@H]3[C+]2[C@@H](CC[C@@H]3C)C1', 'C[C@H]1C(C)(C)[C@@H]2[CH+]CC[C@@](C)(C2)C12CC2', 'CC(C)[C@H]1CC[C+]2[C@H](C1)[C@@H](C)C=C[C@@H]2C', 'C[C@H]1CCC[C@@H]2

In [41]:
#model module
import torch.nn as nn

class ecloud_regre(nn.Module):
    """Regression head on top of the SMILES embedding wrapper.

    Args:
        hidden_dim: width of the embedding fed to the head and of its hidden layer.
        out_dim: number of regression outputs.
    """

    def __init__(self, hidden_dim, out_dim):
        super().__init__()
        # PSO_format_model is a plain Python object, not an nn.Module, so the
        # encoder weights are not registered as parameters of this module;
        # only the layers in `pred` appear in parameters() / state_dict().
        self.encoder = PSO_format_model(encoder, tokenizer, DEVICE)
        head_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.pred = nn.Sequential(*head_layers)

    def forward(self, data):
        """Embed ``data['smiles']`` and map the embeddings through the head.

        Returns:
            tensor with one row of ``out_dim`` predictions per input SMILES.
        """
        return self.pred(self.encoder.seq_to_emb(data['smiles']))

# Instantiate the regressor (256-d embedding -> 1 target) and place it once on
# the target device; moving it inside the batch loop is redundant.
model = ecloud_regre(256, 1)
model = model.to(DEVICE)

# Smoke test: push the (still untrained) model over the test loader.
# no_grad keeps the outputs free of autograd graphs, which would otherwise be
# retained for every batch once the predictions are concatenated.
batch_preds = []
with torch.no_grad():
    for data in tst_loader:
        batch_preds.append(model(data))
# Concatenate once at the end instead of growing a tensor batch by batch.
y_pred = torch.cat(batch_preds) if batch_preds else torch.tensor([]).to(DEVICE)

In [49]:
#train module
import time
from datetime import datetime
def train(model, train_loader, optimizer, device):
    """Run one training epoch with an L1 (MAE) objective.

    Args:
        model: module called as ``model(batch)`` that returns ``(n, 1)`` predictions.
        train_loader: loader yielding dict batches with a ``'label'`` entry.
        optimizer: optimizer over the trainable parameters of ``model``.
        device: device the model and labels are moved to.

    Returns:
        tuple ``(mae, rmse)``: the sample-weighted mean absolute error as a
        Python float, and the epoch RMSE as a 0-dim tensor detached from the
        autograd graph.
    """
    model = model.to(device)
    model.train()

    criterion = torch.nn.L1Loss(reduction='none')
    loss_all = 0
    batch_preds, batch_targets = [], []
    for data in train_loader:
        optimizer.zero_grad()
        y_pred = model(data)
        y = data['label'].view(-1, 1).to(device)

        loss = criterion(y_pred, y).mean()
        loss.backward()
        # Clip to keep a single noisy batch from producing a huge update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        optimizer.step()

        # Weight the batch mean by its size so the epoch value is a true
        # per-sample average even when the last batch is smaller.
        loss_all += loss.item() * y.size(0)
        # Detach before storing: keeping graph-attached predictions would hold
        # every batch's autograd graph in memory until the epoch ends.
        batch_preds.append(y_pred.detach())
        batch_targets.append(y)

    if batch_preds:
        total_y_pred = torch.cat(batch_preds)
        total_y_true = torch.cat(batch_targets)
    else:
        total_y_pred = torch.tensor([]).to(device)
        total_y_true = torch.tensor([]).to(device)
    rmse = torch.sqrt(F.mse_loss(total_y_true, total_y_pred))
    return loss_all / len(train_loader.dataset), rmse

def eval(model, tst_loader, device):
    """Evaluate ``model`` on a labeled loader without updating it.

    Args:
        model: module called as ``model(batch)`` that returns ``(n, 1)`` predictions.
        tst_loader: loader yielding dict batches with a ``'label'`` entry.
        device: device the model and labels are moved to.

    Returns:
        tuple ``(mae, rmse)``: the sample-weighted mean absolute error as a
        Python float, and the RMSE over all samples as a 0-dim tensor.

    Note:
        The model is left in evaluation mode; ``train`` switches it back at
        the start of the next epoch.
    """
    model = model.to(device)
    # Evaluation mode, so layers such as dropout/batch-norm behave deterministically.
    model.eval()

    criterion = torch.nn.L1Loss(reduction='none')
    loss_all = 0
    batch_preds, batch_targets = [], []
    # No gradients are needed for evaluation; skipping them saves memory and time.
    with torch.no_grad():
        for data in tst_loader:
            y_pred = model(data)
            y = data['label'].view(-1, 1).to(device)
            batch_preds.append(y_pred)
            batch_targets.append(y)
            # Weight the batch mean by its size to get a per-sample average.
            loss_all += criterion(y_pred, y).mean().item() * y.size(0)

    if batch_preds:
        total_y_pred = torch.cat(batch_preds)
        total_y_true = torch.cat(batch_targets)
    else:
        total_y_pred = torch.tensor([]).to(device)
        total_y_true = torch.tensor([]).to(device)
    rmse = torch.sqrt(F.mse_loss(total_y_true, total_y_pred))

    return loss_all / len(tst_loader.dataset), rmse

import os

n_epochs = 100
lr = 5e-4
optimizer = torch.optim.Adam(model.parameters(), lr = lr)
best_val_mae = float('inf')
best_val_rmse = float('inf')
best_epoch = 0
best_state = None  # weights of the epoch with the lowest validation RMSE
date = datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
for epoch in range(n_epochs):
    t_s = time.time()
    trn_mae, trn_rmse = train(model, trn_loader, optimizer, DEVICE)
    val_mae, val_rmse = eval(model, valid_loader, DEVICE)
    t_e = time.time()
    # Plain floats keep the bookkeeping free of tensors.
    trn_rmse, val_rmse = float(trn_rmse), float(val_rmse)
    if best_val_rmse >= val_rmse:
        best_val_rmse = val_rmse
        best_val_mae = val_mae
        best_epoch = epoch
        # Snapshot on CPU so later optimizer steps cannot modify the copy and
        # the checkpoint can be loaded on any device.
        best_state = {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}
    print(f'Epoch: {epoch:03d}, LR: {lr:5f}, Train MAE: {trn_mae:.7f}, Train RMSE: {trn_rmse:.7f}, Valid MAE: {val_mae:.7f}, Valid RMSE: {val_rmse:.7f}, Best Valid RMSE: {best_val_rmse:.7f}, Best Epoch: {best_epoch:03d}, Time:{(t_e-t_s)/60:.3f} min')
# Save the parameters of the best validation epoch (falling back to the final
# weights if no epoch ever improved, e.g. when the validation RMSE is NaN).
os.makedirs("model_record", exist_ok=True)
torch.save(best_state if best_state is not None else model.state_dict(), f"model_record/ecloud_reg_paddle_{date}.pth")

Epoch: 000, LR: 0.000500, Train MAE: 9.7103853, Train RMSE: 11.5048351, Valid MAE: 9.8144123, Valid RMSE: 11.5938349, Best Valid RMSE: 11.5938349, Best Epoch: 000, Time:0.550 min
Epoch: 001, LR: 0.000500, Train MAE: 9.4157882, Train RMSE: 11.3306274, Valid MAE: 9.9230983, Valid RMSE: 11.7602301, Best Valid RMSE: 11.5938349, Best Epoch: 000, Time:0.557 min
Epoch: 002, LR: 0.000500, Train MAE: 9.3884908, Train RMSE: 11.2604036, Valid MAE: 9.7890697, Valid RMSE: 11.5691833, Best Valid RMSE: 11.5691833, Best Epoch: 002, Time:0.529 min
Epoch: 003, LR: 0.000500, Train MAE: 9.3486824, Train RMSE: 11.2022820, Valid MAE: 9.7755542, Valid RMSE: 11.5531330, Best Valid RMSE: 11.5531330, Best Epoch: 003, Time:0.540 min
Epoch: 004, LR: 0.000500, Train MAE: 9.3277785, Train RMSE: 11.1984711, Valid MAE: 9.7768602, Valid RMSE: 11.5643015, Best Valid RMSE: 11.5531330, Best Epoch: 003, Time:0.536 min
Epoch: 005, LR: 0.000500, Train MAE: 9.3010987, Train RMSE: 11.1538038, Valid MAE: 9.7498393, Valid RMSE:

In [50]:
#get test results
import os
import pandas as pd
from datetime import datetime

model = ecloud_regre(256, 1)  # build the architecture first, then load weights into it
# map_location lets a checkpoint saved on one device be loaded on another.
model.load_state_dict(torch.load("model_record/ecloud_reg_paddle_2024-09-11-18-07-16.pth", map_location=DEVICE))   # need to complete with a detailed file name
model.to(DEVICE)  # move the model to the target device (e.g. a GPU)
model.eval()
date = datetime.today().strftime('%Y-%m-%d-%H-%M-%S')

# Predictions are assigned to the CSV rows by position, so they must be
# produced in file order: build a dedicated, unshuffled loader.
ordered_tst_loader = get_tst_dataloader(test_csv_path, label_name=None, shuffle=False)
batch_preds = []
with torch.no_grad():  # inference only; no autograd graphs needed
    for data in ordered_tst_loader:
        batch_preds.append(model(data))
y_pred = torch.cat(batch_preds)

# NOTE(review): predictions come from `test_csv_path` while the frame below is
# read from 'data/test.csv' — confirm both files list the same molecules in
# the same order.
test_df = pd.read_csv('data/test.csv')
if len(test_df) != len(y_pred):
    raise ValueError(f'{len(y_pred)} predictions for {len(test_df)} rows in data/test.csv')
test_df['pred'] = y_pred.cpu().numpy().reshape(-1)
os.makedirs('paddle_results', exist_ok=True)
test_df.to_csv(f'paddle_results/result{date}.csv', index=False)


regression via geometric information

1. Data processing

In [1]:
#collate function
import torch
from torch_geometric.data import Data, Batch
import numpy as np

# class DownstreamCollateFn(object):
#     def __init__(self, task_type='regr', is_inference=True):
#         self.atom_names = ["atomic_num", "formal_charge", "degree", "chiral_tag", "total_numHs", "is_aromatic", "hybridization"]
#         self.bond_names = ["bond_dir", "bond_type", "is_in_ring"]
#         self.bond_float_names = ["bond_length"]
#         self.bond_angle_float_names = ["bond_angle"]
#         self.task_type = task_type
#         self.is_inference = is_inference

#     def _flat_shapes(self, d):
#         for name in d:
#             d[name] = d[name].reshape([-1])

#     def __call__(self, data_list):
#         atom_bond_graph_list = []
#         bond_angle_graph_list = []
#         compound_class_list = []

#         for data in data_list:
#             compound_class_list.append(data['Label'])
#             data = data['Graph']
#             # Atom bond graph construction
#             node_features = torch.tensor([data[name].reshape(-1, 1) for name in self.atom_names], dtype=torch.float).squeeze(-1)
#             edge_index = torch.tensor(data['edges'], dtype=torch.long).t().contiguous()
#             edge_features = torch.tensor([data[name].reshape(-1, 1) for name in self.bond_names + self.bond_float_names], dtype=torch.float).squeeze(-1)

#             atom_bond_graph = Data(x=node_features, edge_index=edge_index, edge_attr=edge_features)

#             # Bond angle graph construction
#             bond_angle_edges = torch.tensor(data['BondAngleGraph_edges'], dtype=torch.long).t().contiguous()
#             bond_angle_features = torch.tensor([data[name].reshape(-1, 1) for name in self.bond_angle_float_names], dtype=torch.float).squeeze(-1)

#             bond_angle_graph = Data(edge_index=bond_angle_edges, edge_attr=bond_angle_features)

#             atom_bond_graph_list.append(atom_bond_graph)
#             bond_angle_graph_list.append(bond_angle_graph)

#         # Batch the graphs
#         atom_bond_graph = Batch.from_data_list(atom_bond_graph_list)
#         bond_angle_graph = Batch.from_data_list(bond_angle_graph_list)

#         return atom_bond_graph, bond_angle_graph, torch.tensor(compound_class_list, dtype=torch.float32)

import torch
import numpy as np
from torch_geometric.data import Data, Batch

class DownstreamCollateFn(object):
    """Collate raw sample dicts into batched atom-bond and bond-angle graphs.

    Each sample is expected to be a dict with a ``'Label'`` entry and a
    ``'Graph'`` dict holding one array per feature name listed below plus
    ``'edges'`` and ``'BondAngleGraph_edges'``.

    Args:
        task_type: stored on the instance; not read by this class.
        is_inference: stored on the instance; not read by this class.
    """

    def __init__(self, task_type='regr', is_inference=True):
        # Per-atom feature names; their order defines the column order of `x`.
        self.atom_names = ["atomic_num", "formal_charge", "degree", "chiral_tag", "total_numHs", "is_aromatic",
                           "hybridization"]
        # Per-bond features; categorical names first, then float-valued ones.
        self.bond_names = ["bond_dir", "bond_type", "is_in_ring"]
        self.bond_float_names = ["bond_length"]
        # Features attached to the edges of the bond-angle graph.
        self.bond_angle_float_names = ["bond_angle"]
        self.task_type = task_type
        self.is_inference = is_inference

    def _flat_shapes(self, d):
        """Flatten every array stored in ``d`` to 1-D, in place."""
        for name, value in d.items():
            d[name] = value.reshape([-1])

    def __call__(self, data_list):
        """Batch a list of samples.

        Returns:
            tuple ``(atom_bond_batch, bond_angle_batch, labels)`` where the
            first two come from ``Batch.from_data_list`` and ``labels`` is a
            float32 tensor with one entry per sample.
        """
        def stack_columns(graph, names):
            # One (n, 1) float column per feature name, joined side by side.
            columns = [torch.tensor(graph[name].reshape([-1, 1]), dtype=torch.float) for name in names]
            return torch.cat(columns, dim=-1)

        bond_feature_names = self.bond_names + self.bond_float_names
        labels, atom_bond_graphs, bond_angle_graphs = [], [], []

        for sample in data_list:
            labels.append(sample['Label'])
            graph = sample['Graph']

            # Edge lists are transposed into the (2, num_edges) layout PyG uses.
            bond_index = torch.tensor(graph['edges'], dtype=torch.long).t().contiguous()
            atom_bond_graphs.append(Data(
                x=stack_columns(graph, self.atom_names),
                edge_index=bond_index,
                edge_attr=stack_columns(graph, bond_feature_names),
            ))

            angle_index = torch.tensor(graph['BondAngleGraph_edges'], dtype=torch.long).t().contiguous()
            bond_angle_graphs.append(Data(
                x=None,  # the bond-angle graph carries edge features only
                edge_index=angle_index,
                edge_attr=stack_columns(graph, self.bond_angle_float_names),
            ))

        return (
            Batch.from_data_list(atom_bond_graphs),
            Batch.from_data_list(bond_angle_graphs),
            torch.tensor(labels, dtype=torch.float32),
        )


class ContrastiveLearningCollateFn(object):
    """Collate (anchor, positive, negative) sample triples into batched graphs.

    Every element of the input list is indexable with positions 0, 1 and 2
    (anchor, positive, negative); each position is a dict with a ``'Graph'``
    entry, and the anchor additionally provides ``'Label'``.

    Args:
        task_type: stored on the instance; not read by this class.
        is_inference: stored on the instance; not read by this class.
    """

    def __init__(self, task_type='regr', is_inference=True):
        self.atom_names = ["atomic_num", "formal_charge", "degree", "chiral_tag", "total_numHs", "is_aromatic", "hybridization"]
        self.bond_names = ["bond_dir", "bond_type", "is_in_ring"]
        self.bond_float_names = ["bond_length"]
        self.bond_angle_float_names = ["bond_angle"]
        self.task_type = task_type
        self.is_inference = is_inference

    def _flat_shapes(self, d):
        """Flatten every array stored in ``d`` to 1-D, in place."""
        for name in d:
            d[name] = d[name].reshape([-1])

    def _stack_columns(self, data, names):
        """Return a float tensor of shape (num_items, len(names)), one column per name."""
        columns = [torch.tensor(data[name].reshape(-1, 1), dtype=torch.float) for name in names]
        return torch.cat(columns, dim=-1)

    def data_to_gs(self, data):
        """Convert one raw graph dict into (atom_bond_graph, bond_angle_graph).

        Feature matrices are laid out as (num_items, num_features) — the same
        layout ``DownstreamCollateFn`` produces and the one PyG batching
        concatenates along. A (num_features, num_items) layout cannot be
        batched when graphs differ in size.
        """
        atom_bond_graph = Data(
            x=self._stack_columns(data, self.atom_names),
            # Edge lists are transposed into the (2, num_edges) layout PyG uses.
            edge_index=torch.tensor(data['edges'], dtype=torch.long).t().contiguous(),
            edge_attr=self._stack_columns(data, self.bond_names + self.bond_float_names),
        )

        bond_angle_graph = Data(
            edge_index=torch.tensor(data['BondAngleGraph_edges'], dtype=torch.long).t().contiguous(),
            edge_attr=self._stack_columns(data, self.bond_angle_float_names),
        )

        return atom_bond_graph, bond_angle_graph

    def __call__(self, data_list):
        """Batch a list of triples.

        Returns:
            7-tuple: anchor atom-bond batch, anchor bond-angle batch, positive
            atom-bond batch, positive bond-angle batch, negative atom-bond
            batch, negative bond-angle batch, and a float32 tensor holding the
            anchor labels.
        """
        roles = ('anchor', 'positive', 'negative')
        atom_bond_lists = {role: [] for role in roles}
        bond_angle_lists = {role: [] for role in roles}
        compound_class_list = []

        for data in data_list:
            # Only the anchor's label is used for the triple.
            compound_class_list.append(data[0]['Label'])
            for position, role in enumerate(roles):
                atom_bond_g, bond_angle_g = self.data_to_gs(data[position]['Graph'])
                atom_bond_lists[role].append(atom_bond_g)
                bond_angle_lists[role].append(bond_angle_g)

        # Batch the graphs, keeping the (atom-bond, bond-angle) order per role.
        batched = []
        for role in roles:
            batched.append(Batch.from_data_list(atom_bond_lists[role]))
            batched.append(Batch.from_data_list(bond_angle_lists[role]))

        return (*batched, torch.tensor(compound_class_list, dtype=torch.float32))


In [69]:
#dataset
import os
from os.path import join, exists

import numpy as np

# from pgl.utils.data import Dataloader
from torch_geometric.loader import DataLoader  # pyg 的数据加载器
# from pahelix.utils.data_utils import save_data_list_to_npz, load_npz_to_data_list
# from pahelix.utils.basic_utils import mp_pool_map
from torch.multiprocessing import Pool  # 用于并行化操作
from rdkit import Chem
from rdkit.Chem import AllChem
from pyscf import gto, scf, tools
import numpy as np
from scipy.ndimage import zoom
import numpy as np
import torch
import h5py
from coati.models.encoding.tokenizers.trie_tokenizer import TrieTokenizer
from coati.models.encoding.tokenizers import get_vocab
__all__ = ['PaddleDataset3d']


class PaddleDataset3d(object):
    """List-backed, in-memory dataset of molecule records with 3D information.

    A record is whatever the caller stores in ``data_list``; the loader
    builder in this class expects dicts with ``'Graph'``, ``'Label'`` and
    ``'Smiles'`` entries.

    Args:
        data_list: records to wrap directly.
        npz_data_path: directory of ``.npz`` shards written by ``save_data``;
            when given, its content replaces ``data_list``.
        npz_data_files: explicit list of shard files; when given, its content
            replaces whatever was loaded before.
    """

    def __init__(self, 
            data_list=None,
            npz_data_path=None,
            npz_data_files=None):

        super(PaddleDataset3d, self).__init__()
        self.data_list = data_list
        self.npz_data_path = npz_data_path
        self.npz_data_files = npz_data_files

        if npz_data_path is not None:
            self.data_list = self._load_npz_data_path(npz_data_path)

        if npz_data_files is not None:
            self.data_list = self._load_npz_data_files(npz_data_files)

    @staticmethod
    def _load_npz_file(npz_file):
        """Read back the record list stored in one shard written by ``_save_npz_data``."""
        # Records are arbitrary Python objects stored in an object array, which
        # requires allow_pickle. SECURITY: unpickling can execute arbitrary
        # code, so only load shards from trusted sources.
        with np.load(npz_file, allow_pickle=True) as archive:
            return archive['data_list'].tolist()

    def _load_npz_data_path(self, data_path):
        """Load and concatenate every ``.npz`` shard in ``data_path`` (sorted by name)."""
        data_list = []
        for f in sorted(f for f in os.listdir(data_path) if f.endswith('.npz')):
            data_list += self._load_npz_file(join(data_path, f))
        return data_list

    def _load_npz_data_files(self, data_files):
        """Load and concatenate the given shard files, in the given order."""
        data_list = []
        for f in data_files:
            data_list += self._load_npz_file(f)
        return data_list

    def _save_npz_data(self, data_list, data_path, max_num_per_file=10000):
        """Write ``data_list`` to ``data_path`` as ``part-XXXXXX.npz`` shards.

        Each shard holds at most ``max_num_per_file`` records in a 1-D object
        array stored under the key ``'data_list'``.
        """
        os.makedirs(data_path, exist_ok=True)
        for part, start in enumerate(range(0, len(data_list), max_num_per_file)):
            chunk = data_list[start:start + max_num_per_file]
            # Fill element by element so numpy never tries to broadcast the
            # records themselves into extra array dimensions.
            packed = np.empty(len(chunk), dtype=object)
            for position, record in enumerate(chunk):
                packed[position] = record
            np.savez(join(data_path, 'part-%06d.npz' % part), data_list=packed)

    def save_data(self, data_path):
        """
        Save the ``data_list`` to the disk specified by ``data_path`` with npz format.
        After that, call ``PaddleDataset3d(npz_data_path=data_path)`` to reload
        the ``data_list``.

        Args:
            data_path(str): the path to the cached npz path.
        """
        self._save_npz_data(self.data_list, data_path)

    def __getitem__(self, key):
        """Index by integer (returns a record) or by slice/list (returns a new dataset)."""
        if isinstance(key, slice):
            start, stop, step = key.indices(len(self))
            return PaddleDataset3d(
                    data_list=[self[i] for i in range(start, stop, step)])
        elif isinstance(key, (int, np.integer)):
            return self.data_list[key]
        elif isinstance(key, list):
            return PaddleDataset3d(
                    data_list=[self[i] for i in key])
        else:
            raise TypeError('Invalid argument type: %s of %s' % (type(key), key))

    def __len__(self):
        """Number of records."""
        return len(self.data_list)

    def transform(self, transform_fn, num_workers=1, drop_none=False):
        """
        Inplace apply `transform_fn` on the `data_list` with multiprocess.

        Args:
            transform_fn: callable applied to each record (it must be
                picklable, since it is sent to worker processes).
            num_workers: size of the worker pool.
            drop_none: when True, records mapped to ``None`` are discarded.
        """
        with Pool(num_workers) as pool:
            data_list = pool.map(transform_fn, self.data_list)
        if drop_none:
            self.data_list = [data for data in data_list if data is not None]
        else:
            self.data_list = data_list
            
    def read_cube_file(self, file_path):
        """Parse a cube file into its density grid and atom list.

        Args:
            file_path: path of the cube file.

        Returns:
            tuple ``(density_data, atom_info, origin, grid_shape)``:
            ``density_data`` is the volumetric data reshaped to ``grid_shape``;
            ``atom_info`` is a one-element list holding a dict with
            ``'atomic_number'`` and ``'coordinates'`` lists; ``origin`` is the
            3x3 array taken from columns 1-3 of the three grid lines;
            ``grid_shape`` is the list of voxel counts per axis.

        NOTE(review): in the standard cube layout the three grid lines carry
        the per-axis voxel step vectors, while the grid origin sits on the
        atom-count line — so ``origin`` here is presumably the axis matrix.
        It is returned unchanged to keep the existing return shape; confirm
        before relying on it.
        """
        with open(file_path, 'r') as f:
            lines = f.readlines()

        # The first two lines are comments; the third starts with the atom count.
        atom_line = lines[2].split()
        num_atoms = int(atom_line[0])

        # Lines 4-6: voxel count followed by three floats, one line per axis.
        grid_info = [list(map(float, lines[i].split())) for i in range(3, 6)]
        grid_shape = [int(abs(info[0])) for info in grid_info]

        origin = np.array(grid_info)[:, 1:4]

        # One line per atom: atomic number, a second field that is skipped,
        # then the x, y, z coordinates.
        mol_coords = []
        mol_atomic = []
        for i in range(6, 6 + num_atoms):
            line_data = list(map(float, lines[i].split()))
            mol_atomic.append(int(line_data[0]))
            mol_coords.append(line_data[2:5])
        atom_info = [{
            'atomic_number': mol_atomic,
            'coordinates': mol_coords
        }]

        # Everything after the atom block is the flattened density grid.
        density_data = []
        for line in lines[6 + num_atoms:]:
            density_data.extend(map(float, line.split()))

        density_data = np.array(density_data).reshape(grid_shape)

        return density_data, atom_info, origin, grid_shape

    def sml2ecloud(self, smiles, out_dir='paddle_data/ecloud'):
        '''
        Compute an electron-density grid for every SMILES and store all of
        them in one HDF5 file.

        For each molecule: build a 3D conformer with RDKit, run a
        Hartree-Fock (STO-3G) calculation with PySCF, dump the density to a
        cube file, read it back and resample it to a 32x32x32 grid.

        Args:
            smiles: non-empty list of SMILES strings.
            out_dir: directory receiving the intermediate cube files and the
                final ``eclouds_<n>.h5`` file; created if missing.

        Raises:
            ValueError: on empty input, an unparsable SMILES, or a molecule
                for which no 3D conformer could be generated.

        Note:
            Per-molecule atom arrays are stacked, so all molecules must have
            the same number of atoms (after adding hydrogens); stacking
            raises otherwise.
        '''
        # Local import keeps this class self-contained (standard library only).
        from urllib.parse import quote

        if len(smiles) == 0:
            raise ValueError('sml2ecloud needs at least one SMILES string')

        os.makedirs(out_dir, exist_ok=True)
        tokenizer = TrieTokenizer(n_seq=128, **get_vocab('mar'))
        target_shape = (32, 32, 32)
        eclouds, atomic_numbers, coords, augmented_tokens = [], [], [], []
        for smile in smiles:
            mol = Chem.MolFromSmiles(smile)
            if mol is None:
                raise ValueError(f'Could not parse SMILES: {smile!r}')

            # mol to points
            mol = Chem.AddHs(mol)  # add explicit hydrogens
            if AllChem.EmbedMolecule(mol) < 0:  # generate 3D coordinates
                raise ValueError(f'Could not generate a 3D conformer for: {smile!r}')
            AllChem.UFFOptimizeMolecule(mol)  # relax with the UFF force field

            conf = mol.GetConformer()
            xyz = ""
            for i, atom in enumerate(mol.GetAtoms()):
                pos = conf.GetAtomPosition(i)
                xyz += f"{atom.GetSymbol()} {pos.x} {pos.y} {pos.z}\n"

            # Pass the net formal charge: with the default charge of 0 an ion
            # would be given the wrong electron count.
            qm_mol = gto.M(atom=xyz, basis="sto-3g", charge=Chem.GetFormalCharge(mol))
            mf = scf.RHF(qm_mol)  # Hartree-Fock calculation
            mf.kernel()

            # SMILES may contain '/' and '\\'; percent-encode it so the name is
            # a valid single path component, and read back the very file that
            # was just written.
            safe_name = quote(smile, safe='')
            cube_path = join(out_dir, f'eclouds_{safe_name}.cube')
            tools.cubegen.density(qm_mol, cube_path, mf.make_rdm1())

            ecloud_density, atom_info, origin, grid_shape = self.read_cube_file(cube_path)
            # Resample each axis independently so non-cubic grids also end up
            # at the target shape.
            zoom_factors = tuple(t / s for t, s in zip(target_shape, ecloud_density.shape))
            ecloud = zoom(ecloud_density, zoom_factors)

            eclouds.append(torch.tensor(ecloud, dtype=torch.double))
            atomic_numbers.append(torch.tensor(atom_info[0]['atomic_number']))
            coords.append(torch.tensor(atom_info[0]['coordinates']))
            augmented_tokens.append(torch.tensor(
                tokenizer.tokenize_text("[CLIP][UNK][SMILES][SUFFIX][MIDDLE]" + smile + "[STOP]", pad=True)))

        # NOTE(review): atomic numbers and tokens are cast to float32 on the
        # assumption that existing readers of these files expect that dtype —
        # confirm; an integer dtype would be more natural.
        eclouds = torch.stack(eclouds)
        atomic_numbers = torch.stack(atomic_numbers).to(torch.float32)
        coords = torch.stack(coords)
        augmented_tokens = torch.stack(augmented_tokens).to(torch.float32)

        with h5py.File(join(out_dir, f'eclouds_{len(smiles)}.h5'), 'w') as f:
            f.create_dataset('ecloud', data=eclouds.numpy())
            f.create_dataset('atomic_number', data=atomic_numbers.numpy())
            f.create_dataset('coords', data=coords.numpy())
            # Python str lists need an explicit variable-length string dtype.
            f.create_dataset('smiles', data=list(smiles), dtype=h5py.string_dtype())
            f.create_dataset('augmented_tokens', data=augmented_tokens.numpy())

    @staticmethod
    def _as_edge_index(edges):
        """Return ``edges`` as a long tensor in (2, num_edges) layout.

        A 2-D input whose second dimension is 2 is treated as a list of
        (source, target) pairs and transposed — the same convention the
        collate classes in this notebook apply to these arrays. Anything else
        is reshaped to two rows. A (2, 2) input is ambiguous and is treated
        as a pair list.
        """
        index = torch.tensor(edges, dtype=torch.long)
        if index.numel() == 0:
            return index.new_empty((2, 0))
        if index.dim() == 2 and index.size(1) == 2:
            return index.t().contiguous()
        return index.reshape(2, -1)

    def get_data_loader(self, batch_size, num_workers=1, shuffle=False, collate_fn=None):
        """Convert every record into a PyG ``Data`` object and wrap them in a loader.

        Args:
            batch_size: number of graphs per batch.
            num_workers: loader worker processes.
            shuffle: whether the loader shuffles.
            collate_fn: forwarded to the loader as a keyword argument.
                NOTE(review): PyG's DataLoader is documented to install its
                own collater, so this argument presumably has no effect —
                confirm against the installed version.

        Returns:
            the DataLoader over the converted graphs.

        NOTE(review): PyG offsets attributes whose name contains 'index' by
        the node count when batching — confirm that is intended for `index`
        and `edge_angle_index`.
        """
        graphs = []
        for index, record in enumerate(self.data_list):
            graph = record['Graph']
            graphs.append(Data(
                atomic_number=torch.tensor(graph['atomic_num'], dtype=torch.long),
                pos=torch.tensor(graph['atom_pos'], dtype=torch.float),
                edge_index=self._as_edge_index(graph['edges']),
                edge_type=torch.tensor(graph['bond_type'], dtype=torch.long).view(-1, 1),
                edge_length=torch.tensor(graph['bond_length'], dtype=torch.float).view(-1, 1),
                y=torch.tensor(record['Label'], dtype=torch.float).view(-1, 1),
                name=record['Smiles'],
                index=index,
                edge_angle_index=self._as_edge_index(graph['BondAngleGraph_edges']),
                edge_angle=torch.tensor(graph['bond_angle'], dtype=torch.float).view(-1, 1),
            ))
        return DataLoader(graphs, 
                batch_size=batch_size, 
                num_workers=num_workers, 
                shuffle=shuffle,
                collate_fn=collate_fn)

In [70]:
#dataloader
import pickle
from sklearn.model_selection import train_test_split


def get_data_loader(mode, batch_size=256):
    """Build the conformation data loaders from the pickled record lists.

    Args:
        mode: ``'train'`` to get ``(train_loader, valid_loader)`` from an
            80/20 split of the training records, or ``'test'`` to get a
            single, unshuffled test loader.
        batch_size: number of graphs per batch for every returned loader.

    Returns:
        ``(train_dl, valid_dl)`` for ``mode='train'``; ``test_dl`` for
        ``mode='test'``.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
    """
    collate_fn = DownstreamCollateFn()
    if mode == 'train':
        # SECURITY: pickle can execute arbitrary code while loading; only use
        # files from a trusted source.
        with open("data/conformation/train.pkl", 'rb') as handle:
            data_list = pickle.load(handle)

        # Fixed random_state keeps the split identical across runs.
        train, valid = train_test_split(data_list, random_state=42, test_size=0.2)
        train, valid = PaddleDataset3d(train), PaddleDataset3d(valid)

        print(f'len train is {len(train)}, len valid is {len(valid)}')

        train_dl = train.get_data_loader(batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
        valid_dl = valid.get_data_loader(batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

        return train_dl, valid_dl
    elif mode == 'test':
        # SECURITY: same pickle caveat as above.
        with open("data/conformation/test.pkl", 'rb') as handle:
            data_list = pickle.load(handle)

        print(f'len test is {len(data_list)}')

        test = PaddleDataset3d(data_list)
        # The test loader is not shuffled, so outputs stay in record order.
        test_dl = test.get_data_loader(batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
        return test_dl
    else:
        # Fail loudly instead of silently returning None for a typo in `mode`.
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

In [73]:
# Build the train/validation loaders with a tiny batch size for quick inspection.
train_data_loader, valid_data_loader = get_data_loader(mode='train', batch_size=2)   

len train is 1381, len valid is 346


In [75]:
# Print the `name` attribute (set from each record's 'Smiles' entry) of the
# first training batch, then stop.
for data in train_data_loader:
    print(data.name)
    break

['C[C@H]1[C+]2CC[C@@H]1[C@@]1(C)CC[C@@H](C2)C1(C)C', 'C[C@@H]1CC[C@H]2[C@@H](C1)[C@]1(C)CCC[CH+][C@H]2C1']


2. Points-ecloud encoder