In [5]:
import pickle as pkl
import json
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from tqdm import tqdm

def generate_3d_structure(smiles):
    """生成分子的三维构象"""
    molecule = Chem.MolFromSmiles(smiles)
    if molecule:
        success = AllChem.EmbedMolecule(molecule, randomSeed=42)
        if success != 0:
            return None
        AllChem.UFFOptimizeMolecule(molecule)
        conformer = molecule.GetConformer()
        coordinates = conformer.GetPositions()
        atom_types = [atom.GetSymbol() for atom in molecule.GetAtoms()]
        bond_types = []
        for bond in molecule.GetBonds():
            start_atom = bond.GetBeginAtomIdx()
            end_atom = bond.GetEndAtomIdx()
            btype = bond.GetBondTypeAsDouble()
            bond_types.append((start_atom, end_atom, btype))
        return {
            "coordinates": coordinates,
            "atom_types": atom_types,
            "bond_types": bond_types
        }
    return None

def random_rotation_matrix():
    """生成随机旋转矩阵（3D）"""
    theta = np.random.uniform(0, 2 * np.pi)
    phi = np.random.uniform(0, 2 * np.pi)
    z = np.random.uniform(0, 2 * np.pi)

    Rz = np.array([
        [np.cos(z), -np.sin(z), 0],
        [np.sin(z),  np.cos(z), 0],
        [0,          0,         1]
    ])
    Ry = np.array([
        [np.cos(phi), 0, np.sin(phi)],
        [0,           1, 0],
        [-np.sin(phi), 0, np.cos(phi)]
    ])
    Rx = np.array([
        [1, 0,          0],
        [0, np.cos(theta), -np.sin(theta)],
        [0, np.sin(theta),  np.cos(theta)]
    ])
    return Rz @ Ry @ Rx

def equivariant_transform(coords):
    """应用随机旋转和平移"""
    coords = np.array(coords)
    R = random_rotation_matrix()
    return coords @ R.T

# 主处理逻辑
results = []
json_path = '/data1/chenyuxuan/Project/MSMLM/data/traindata/pubchem_data.jsonl'

# 获取总行数（减去 header）
with open(json_path, 'r') as f:
    total_lines = sum(1 for _ in f) - 1  # 减去第一行 header

with open(json_path, 'r') as f:
    for i, line in tqdm(enumerate(f), total=total_lines + 1, desc="Processing molecules"):
        if i == 0:
            continue
        data = json.loads(line)
        smiles = data.get("smiles") or data.get("SMILES")
        embedding = data.get("embedding")
        if not smiles:
            continue
        mol_data = generate_3d_structure(smiles)
        if mol_data:
            for _ in range(10):  # 生成 10 个独立样本
                coords = equivariant_transform(mol_data["coordinates"])
                results.append({
                    "CID": data.get("CID"),
                    "IUPAC": data.get("IUPAC"),
                    "smiles": smiles,
                    "embedding": embedding,
                    "coordinates": coords,
                    "atom_types": mol_data["atom_types"],
                    "bond_types": mol_data["bond_types"],
                })

# 保存为 .pkl
with open("/data1/chenyuxuan/Project/MSMLM/data/traindata/pubchem_data.pkl", "wb") as f:
    pkl.dump(results, f)

print(f"共保存 {len(results)} 条数据，每个 SMILES 生成 10 个增强样本。")


Processing molecules:   0%|          | 0/4970 [00:00<?, ?it/s][01:58:50] Molecule does not have explicit Hs. Consider calling AddHs()
[01:58:50] Molecule does not have explicit Hs. Consider calling AddHs()
Processing molecules:   0%|          | 2/4970 [00:00<05:28, 15.13it/s][01:58:50] Molecule does not have explicit Hs. Consider calling AddHs()
[01:58:50] Molecule does not have explicit Hs. Consider calling AddHs()
[01:58:50] Molecule does not have explicit Hs. Consider calling AddHs()
[01:58:50] Molecule does not have explicit Hs. Consider calling AddHs()
[01:58:50] Molecule does not have explicit Hs. Consider calling AddHs()
[01:58:50] Molecule does not have explicit Hs. Consider calling AddHs()
[01:58:50] Molecule does not have explicit Hs. Consider calling AddHs()
[01:58:50] Molecule does not have explicit Hs. Consider calling AddHs()
[01:58:50] Molecule does not have explicit Hs. Consider calling AddHs()
[01:58:50] Molecule does not have explicit Hs. Consider calling AddHs()
[01:

共保存 49640 条数据，每个 SMILES 生成 10 个增强样本。


In [3]:
import pickle as pkl
import json
with open('/data1/chenyuxuan/Project/MSMLM/data/traindata/pubchem_data.pkl', 'rb') as f:
    data = pkl.load(f)
count = {}
print(data[0].keys())
for d in data:
    for a in d["atom_types"]:
        if a not in count:
            count[a] = 0
        count[a] += 1
for k, v in count.items():
    print(f"{k}: {v}")
with open('/data1/chenyuxuan/Project/MSMLM/data/traindata/pubchem_atom_types.json', 'w') as f:
    json.dump(count, f, indent=4)

dict_keys(['CID', 'IUPAC', 'smiles', 'embedding', 'coordinates', 'atom_types', 'bond_types'])
C: 823160
O: 223360
N: 99880
H: 1103800
P: 9210
Cl: 6730
S: 11770
As: 80
Br: 1440
Ca: 20
Co: 60
Se: 100
Hg: 70
I: 1330
K: 50
Mg: 10
Na: 640
Ni: 30
Fe: 40
W: 10
B: 150
F: 4500
Si: 30
Te: 20
Al: 10
Cd: 10
Pt: 50
Rh: 10
Li: 10
Be: 50
Mn: 20
Pr: 10
Ag: 10


In [4]:
import numpy as np
from rdkit import Chem
from rdkit.Chem import rdchem
import matplotlib.colors as mcolors

# 替换为你给出的统计频次
atom_hist_input = count
# 1. 排序
sorted_atoms = sorted(atom_hist_input.items(), key=lambda x: -x[1])
atom_decoder = [k for k, v in sorted_atoms]
atom_encoder = {atom: i for i, atom in enumerate(atom_decoder)}
atom_hist = {atom: freq for atom, freq in sorted_atoms}
print("Sorted atom frequencies:", atom_hist)

Sorted atom frequencies: {'H': 1103800, 'C': 823160, 'O': 223360, 'N': 99880, 'S': 11770, 'P': 9210, 'Cl': 6730, 'F': 4500, 'Br': 1440, 'I': 1330, 'Na': 640, 'B': 150, 'Se': 100, 'As': 80, 'Hg': 70, 'Co': 60, 'K': 50, 'Pt': 50, 'Be': 50, 'Fe': 40, 'Ni': 30, 'Si': 30, 'Ca': 20, 'Te': 20, 'Mn': 20, 'Mg': 10, 'W': 10, 'Al': 10, 'Cd': 10, 'Rh': 10, 'Li': 10, 'Pr': 10, 'Ag': 10}


In [None]:
# 2. 原子颜色（使用 matplotlib 颜色表）
colors = list(mcolors.TABLEAU_COLORS.values())
while len(colors) < len(atom_decoder):
    colors += list(mcolors.CSS4_COLORS.values())
colors_dic = colors[:len(atom_decoder)]
print("Colors assigned to atoms:", colors_dic)

In [None]:

# 3. 原子半径
pt = Chem.GetPeriodicTable()
radius_dic = [pt.GetRvdw(atom) if pt.GetRvdw(atom) else 0.3 for atom in atom_decoder]

# 4. 键长矩阵（单位：皮米）
def generate_bond_matrix(radius, scale):
    n = len(radius)
    mat = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            mat[i, j] = (radius[i] + radius[j]) * scale * 100  # 乘以100转为皮米
    return mat.tolist()

bonds1 = generate_bond_matrix(radius_dic, 1.1)
bonds2 = generate_bond_matrix(radius_dic, 1.0)
bonds3 = generate_bond_matrix(radius_dic, 0.9)
lennard_jones_rm = generate_bond_matrix(radius_dic, 1.0)

# 5. 构建最终配置字典
dataset_params = {}
dataset_params['your_dataset_name'] = {
    'atom_encoder': atom_encoder,
    'atom_decoder': atom_decoder,
    'atom_hist': atom_hist,
    'colors_dic': colors_dic,
    'radius_dic': radius_dic,
    'bonds1': bonds1,
    'bonds2': bonds2,
    'bonds3': bonds3,
    'lennard_jones_rm': lennard_jones_rm
}

with open('/data1/chenyuxuan/Project/MSMLM/data/traindata/chatmol/chatmol_dataset_params.json', 'w') as f:
    json.dump(dataset_params, f, indent=4)
    

In [None]:
import pickle as pkl
with open("/data1/lvchangwei/GNN/orgin_data/ZINC/processed/AAAA_processed.pkl", "rb") as f:
    data = pkl.load(f)
print(data[0].keys())
print(len(data))
for key, value in data[0].items():
    print(key, value)

dict_keys(['smiles', 'zinc_id', 'atom_count', 'atom_features', 'edge_index', 'edge_features', 'mol_molecular_weight', 'mol_logp', 'mol_num_h_acceptors', 'mol_num_h_donors', 'mol_tpsa', 'mol_num_rotatable_bonds', 'atom_coords'])
3730
smiles CO[C@H]1OC[C@@H](O)[C@H](O)[C@H]1O
zinc_id 4371221
atom_count 11
atom_features [{'atom_atomic_num': 6, 'atom_formal_charge': 0, 'atom_is_aromatic': 0, 'atom_is_in_ring': 0, 'atom_total_num_hs': 3, 'atom_hybridization': 'SP3'}, {'atom_atomic_num': 8, 'atom_formal_charge': 0, 'atom_is_aromatic': 0, 'atom_is_in_ring': 0, 'atom_total_num_hs': 0, 'atom_hybridization': 'SP3'}, {'atom_atomic_num': 6, 'atom_formal_charge': 0, 'atom_is_aromatic': 0, 'atom_is_in_ring': 1, 'atom_total_num_hs': 1, 'atom_hybridization': 'SP3'}, {'atom_atomic_num': 8, 'atom_formal_charge': 0, 'atom_is_aromatic': 0, 'atom_is_in_ring': 1, 'atom_total_num_hs': 0, 'atom_hybridization': 'SP3'}, {'atom_atomic_num': 6, 'atom_formal_charge': 0, 'atom_is_aromatic': 0, 'atom_is_in_ring': 1,