In [37]:
#1. 3d point encoder
import torch.nn as nn
import torch
from coati.models.encoding.e_gcl_sparse import e_gcl_sparse # encoder layer
from coati.common.periodic_table import XY_ONE_HOT_FULL

#settings
n_layer_e3gnn: int = 4
n_layer_xformer: int = 16
n_hidden_xformer: int = 256
n_hidden_e3nn: int = 256
msg_cutoff_e3nn: float = 4.0
n_embd_common: int = 256
n_head: int = 8
n_seq: int = 200
n_tok: int = 4
biases: bool = True
torch_emb: bool = False
residual: bool = False
norm_clips: bool = True
norm_embed: bool = False
token_mlp: bool = True  # Do we use a nonlinear MLP to convert HCLIP into a token.
use_point_encoder: bool = True  # if false, do not use a point encoder at all.
old_architecture: bool = False
device: torch.device = torch.device("cpu")
dtype: torch.dtype = torch.float
embed_dim = n_embd_common

#encoder function
class e3gnn_clip(torch.nn.Module):
    def __init__(
        self,
        in_node_nf: int = len(XY_ONE_HOT_FULL(1)),
        hidden_nf: int = 128,
        device: str = "cpu",
        act_fn: str = "SiLU",
        n_layers: int = 5,
        instance_norm: bool = True,
        message_cutoff: int = 5,
        dtype=torch.float,
        torch_emb: bool = False,
        residual: bool = False,
        dropout: float = 0.1,
    ):
        """
        The Welling research code is quadratic in batch size.
        and has no instancenorm. This fixes that.
        This also has no edge feature b/c bonds aren't real

        h_l => n_graph X n_node X n_hidden_features
        x_l => n_graph X n_node X n_dim
        e_ij => n_graph X n_node X n_node X n_edge_features

        Args:
            in_node_nf: number of input features for each node (atom)
            in_edge_nf: number of input featuers for each edge (bond)
            hidden_nf: dimension of the hidden representation (per atom)
            code_nf: dimension of a code conditioning the final aggregation. (optional)
            residual_feature: whether to include residual-like h0 in the node_model
        """
        super(e3gnn_clip, self).__init__()
        self.dtype = dtype
        self.hidden_nf = hidden_nf

        if not torch_emb:
            self.torch_emb = False
            self.in_node_nf = in_node_nf
            self.emb = None
        else:
            self.torch_emb = True
            self.in_node_nf = hidden_nf
            self.emb = nn.Embedding(84, self.hidden_nf, device=device, dtype=dtype)

        self.device = device
        self.n_layers = n_layers
        self.instance_norm = instance_norm
        self.message_cutoff = torch.tensor(message_cutoff, requires_grad=False)

        assert dropout >= 0.0 and dropout < 1.0
        self.dropout = dropout

        if act_fn == "SiLU":
            self.act_fn = nn.SiLU()
        elif act_fn == "GELU":
            self.act_fn = nn.GELU()
        else:
            raise Exception("Bad act_fn")

        ### Encoder
        if self.torch_emb:
            self.embedding = torch.nn.Identity()
        else:
            self.embedding = nn.Linear(self.in_node_nf, hidden_nf)

        if instance_norm:
            self.embedding_norm = torch.nn.InstanceNorm1d(hidden_nf)
        else:
            self.embedding_norm = torch.nn.Identity()

        self.node_dec = nn.Sequential(
            nn.Linear(self.hidden_nf, self.hidden_nf),
            self.act_fn,
            nn.Dropout(p=self.dropout) if self.dropout else nn.Identity(),
            nn.Linear(self.hidden_nf, self.hidden_nf),
        )

        for i in range(0, n_layers):
            self.add_module(
                "gcl_%d" % i,
                e_gcl_sparse(
                    self.hidden_nf,
                    act_fn=self.act_fn,
                    residual=residual,
                    attention=False,
                    instance_norm=instance_norm,
                    residual_nf=(in_node_nf if residual else 0),
                    dropout=dropout,
                    prop_coords=False,
                ),
            )

        self.to(self.device)

    def forward(self, atoms: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
        """
        atoms: batch X max_n_atom long tensor of atomic numbers.
        coords: node coordinates.
        """
        if self.torch_emb:
            assert (atoms > 84).sum().detach().item() == 0
            nodes = self.emb(atoms)
        else:
            with torch.no_grad():
                ans = atoms.tolist()
                nodes = torch.tensor(
                    [[XY_ONE_HOT_FULL(int(atom)) for atom in mol] for mol in ans],  #transform atomic number into atomic table one-hot(x = 18, y = 10)
                    dtype=torch.float32,
                    device=atoms.device,
                    requires_grad=False,
                )
        node_mask = (atoms > 0).to(atoms.device, torch.float)
        assert nodes.isfinite().all()
        assert coords.isfinite().all()
        assert node_mask.isfinite().all()
        # print('nodes', nodes)
        # bsize x n_atoms x hidden_nf
        # print('emb nodes', self.embedding(nodes).size())
        h = self.embedding_norm(self.embedding(nodes))
        for i in range(0, self.n_layers):
            h, _ = self._modules["gcl_%d" % i](h, coords, node_mask, h0=nodes)
        h = self.node_dec(h)
        h = h * node_mask.unsqueeze(-1)
        natoms = torch.maximum(node_mask.sum(-1), torch.ones_like(node_mask.sum(-1)))
        h = torch.sum(h, dim=1) / natoms.unsqueeze(-1)
        return h


#encoder
point_encoder = e3gnn_clip(
    device=device,
    dtype=dtype,
    hidden_nf=n_hidden_e3nn,
    message_cutoff=msg_cutoff_e3nn,
    dropout=0.0,
    torch_emb=torch_emb,
    residual=residual,
    n_layers=n_layer_e3gnn,
)

#norm
point_to_clip = nn.Sequential(
        nn.LayerNorm(point_encoder.hidden_nf),
        nn.Linear(point_encoder.hidden_nf, embed_dim),
    )

#full encoder
def encode_points(atoms: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
    """
    Embeds coordinates/atoms and projects into the latent space.

    If point encoder is not used, this returns a zero tensor.
    """
    if use_point_encoder:
        return point_to_clip(point_encoder(atoms, coords))
    else:
        return torch.zeros(atoms.shape[0], embed_dim).to(device)

In [9]:
import random
import numpy as np
b = 2
n = 2
d = 3
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
mol_coords = torch.rand((b, n, d))
mol_atom = torch.randint(1, 7, (b, n))
print(mol_coords, '\n', mol_atom)
h_points = encode_points(mol_atom, mol_coords)
print(h_points.size())

tensor([[[0.8823, 0.9150, 0.3829],
         [0.9593, 0.3904, 0.6009]],

        [[0.2566, 0.7936, 0.9408],
         [0.1332, 0.9346, 0.5936]]]) 
 tensor([[3, 1],
        [4, 5]])
ans [[3, 1], [4, 5]]
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
nodes tensor([[[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

In [38]:
#2. ecloud encoder
from typing import List, Union
import torch
from torch import nn
import torch.nn.functional as F

#settings
n_layer_e3gnn: int = 4
n_layer_xformer: int = 16
n_hidden_xformer: int = 256
n_hidden_e3nn: int = 256
msg_cutoff_e3nn: float = 4.0
n_embd_common: int = 256
n_head: int = 8
n_seq: int = 200
n_tok: int = 4
biases: bool = True
torch_emb: bool = False
residual: bool = False
norm_clips: bool = True
norm_embed: bool = False
token_mlp: bool = True  # Do we use a nonlinear MLP to convert HCLIP into a token.
use_point_encoder: bool = True  # if false, do not use a point encoder at all.
old_architecture: bool = False
device: torch.device = torch.device("cpu")
dtype: torch.dtype = torch.float
embed_dim = n_embd_common

#encoder layer
class Conv3DEncoder(nn.Module):
    def __init__(self, in_channels=1, kernel_size=3, stride=1, padding=1, dilation=1, d_model=256):
        super(Conv3DEncoder, self).__init__()
        self.d_model = d_model
        self.conv1 = nn.Conv3d(in_channels, d_model // 4, kernel_size, stride, padding, dilation)
        self.conv2 = nn.Conv3d(d_model // 4, d_model // 2, kernel_size, stride, padding, dilation)
        self.conv3 = nn.Conv3d(d_model// 2, d_model, kernel_size, stride, padding, dilation)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        self.relu = nn.LeakyReLU()

    def forward(self, x):
        bz = x.size(0)
        x = x.unsqueeze(1)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(bz, -1, self.d_model)
        sl = x.size(1)
        # prepare input for decoder
        x = x.transpose(0, 1)
        

        src_padding_mask = torch.zeros((bz, sl), dtype=torch.bool).to(x.device)

        # encoder_out = x, src_padding_mask
        # return encoder_out
        return x

#norm
eclouds_to_clip = nn.Sequential(
    nn.LayerNorm(64 * 768),
    nn.Linear(64 * 768, embed_dim),
)

#full ecloud_encoder
ecloud_encoder = Conv3DEncoder(d_model=768)
def encode_eclouds(eclouds: torch.Tensor) -> torch.Tensor:
    """
    Embeds eclouds into the latent space.

    """
    ecloud_emb = ecloud_encoder(eclouds).transpose(0, 1) # (b, s, 768)
    ecloud_emb = ecloud_emb.reshape(ecloud_emb.shape[0], -1) # (b, s*768)
    return eclouds_to_clip(ecloud_emb)


In [49]:
b = 2
x = 32
y = 32
z = 32
mol_ecloud = torch.rand((b, x, y, z))
print(mol_ecloud.dtype)
h_ecloud = encode_eclouds(mol_ecloud)
print(h_ecloud.size())

torch.float32
torch.Size([2, 128])


data process

In [39]:
#generate encoder example
from rdkit import Chem
from rdkit.Chem import AllChem
from pyscf import gto, scf, tools
import numpy as np
from scipy.ndimage import zoom
import numpy as np
import torch
import h5py
from coati.models.encoding.tokenizers.trie_tokenizer import TrieTokenizer
from coati.models.encoding.tokenizers import get_vocab

def read_cube_file(file_path):
    with open(file_path, 'r') as f:
        lines = f.readlines()

    atom_line = lines[2].split()     # 跳过前两行（注释行）
    num_atoms = int(atom_line[0])    # 获取原子数和原点坐标

    grid_info = [list(map(float, lines[i].split())) for i in range(3, 6)]     # 第三行到第五行分别是网格信息
    grid_shape = [int(abs(info[0])) for info in grid_info]  # 获取网格维度
    
    origin = np.array(grid_info)[:, 1:4]     # 网格大小和原点
    
    # atom_info = [list(map(float, lines[i].split())) for i in range(6, 6 + num_atoms)]     # 原子信息 (读取接下来 num_atoms 行的数据)
    atom_info = []
    mol_coords = []
    mol_atomic = []
    for i in range(6, 6 + num_atoms):
        line_data = list(map(float, lines[i].split()))
        atomic_number = int(line_data[0])  # 原子序号
        coordinates = line_data[2:5]  # x, y, z 坐标
        mol_coords.append(coordinates)
        mol_atomic.append(atomic_number)
    atom_info.append({
        'atomic_number': mol_atomic,
        'coordinates': mol_coords
    })
    density_data = []
    for line in lines[6 + num_atoms:]:
        density_data.extend(map(float, line.split()))
    
    density_data = np.array(density_data).reshape(grid_shape)
    
    return density_data, atom_info, origin, grid_shape

def sml2ecloud(smiles):
    '''
    smiles: list of smile stringa
    return: h5 file, 
    '''
    eclouds = torch.tensor([])
    atomic_number = torch.tensor([])
    coords = torch.tensor([])
    augmented_tokens = torch.tensor([])
    tokenizer = TrieTokenizer(n_seq=128, **get_vocab('mar'))
    for smile in smiles:
        mol = Chem.MolFromSmiles(smile)

        # mol to points
        mol = Chem.AddHs(mol)  # 添加氢原子
        AllChem.EmbedMolecule(mol)  # 生成3D坐标
        AllChem.UFFOptimizeMolecule(mol)  # 用UFF力场优化
        
        conf = mol.GetConformer()
        xyz = ""
        for i, atom in enumerate(mol.GetAtoms()):
            pos = conf.GetAtomPosition(i)
            xyz += f"{atom.GetSymbol()} {pos.x} {pos.y} {pos.z}\n"

        mol = gto.M(atom=xyz, basis="sto-3g")  # 定义分子并选择基组
        mf = scf.RHF(mol)  # Hartree-Fock计算
        mf.kernel()  # 运行计算

        # 生成电子密度
        tools.cubegen.density(mol, f'examples/eclouds_{smile}.cube', mf.make_rdm1())
        # # 添加 batch 维度和 channel 维度，形状为 (1, 1, x, y, z)
        # density_tensor = density_tensor.unsqueeze(0).unsqueeze(0)

        # print(density_tensor.shape)  # 输出形状为 (1, 1, x, y, z)

        # print("电子密度网格形状：", density_data.shape)
        # print("原子信息：", atom_info)
        # print("网格原点：", origin)

        ecloud_density, atom_info, origin, grid_shape = read_cube_file(f'examples/eclouds_{smile}.cube')
        print('atom info', atom_info)
        n = ecloud_density.shape[0]  # get size of raw ecloud
        target_shape = (32, 32, 32) # 使用 scipy 的 zoom 函数将电子密度插值到 (32, 32, 32)
        ecloud = zoom(ecloud_density, (target_shape[0] / n, target_shape[1] / n, target_shape[2] / n))
        
        ecloud = torch.tensor(ecloud, dtype=torch.double).unsqueeze(0)
        mol_atomic = torch.tensor(atom_info[0]['atomic_number']).unsqueeze(0)
        mol_coords = torch.tensor(atom_info[0]['coordinates']).unsqueeze(0)
        augmented_token = torch.tensor(tokenizer.tokenize_text("[CLIP][UNK][SMILES][SUFFIX][MIDDLE]" + smile + "[STOP]", pad=True)).unsqueeze(0)
        
        eclouds = torch.cat((eclouds, ecloud), dim=0)
        atomic_number = torch.cat((atomic_number, mol_atomic), dim=0)
        coords = torch.cat((coords, mol_coords), dim=0)
        augmented_tokens = torch.cat((augmented_tokens, augmented_token))
    # 假设 density_32 是 (32, 32, 32) 的 3D 电子云数据
    with h5py.File(f'examples/eclouds_{len(smiles)}.h5', 'w') as f:
        f.create_dataset('ecloud', data=eclouds)
        f.create_dataset('atomic_number', data=atomic_number)
        f.create_dataset('coords', data=coords)
        f.create_dataset('smiles', data=smiles)
        f.create_dataset('augmented_tokens', data=augmented_tokens)
smiles = "CCO"  # 乙醇
smiles = [smiles, smiles]
sml2ecloud(smiles)



converged SCF energy = -152.127991065501
atom info [{'atomic_number': [6, 6, 8, 1, 1, 1, 1, 1, 1], 'coordinates': [[-1.723572, 0.319103, 0.094505], [0.881941, -0.895013, -0.07507], [2.613345, 0.78721, -1.164387], [-2.380077, 0.880869, -1.81821], [-3.089854, -1.049592, 0.908722], [-1.649145, 2.017366, 1.325404], [1.527014, -1.4826, 1.840037], [0.761771, -2.611674, -1.275931], [3.058577, 2.03433, 0.164931]]}]
converged SCF energy = -152.128702196742
atom info [{'atomic_number': [6, 6, 8, 1, 1, 1, 1, 1, 1], 'coordinates': [[-1.890702, -0.202512, 0.39852], [0.927044, -0.368813, -0.142042], [1.886061, 2.052294, -0.608457], [-2.22926, 1.002338, 2.083537], [-2.891573, 0.622128, -1.251873], [-2.655278, -2.121252, 0.769064], [1.899849, -1.215916, 1.51865], [1.236723, -1.596591, -1.820879], [3.717136, 1.828325, -0.946519]]}]


In [40]:
#load data
with h5py.File(f'examples/eclouds_{len(smiles)}.h5', 'r') as f:
    eclouds = torch.tensor(f['ecloud'][:])
    atomic_number = torch.tensor(f['atomic_number'][:])
    coords = torch.tensor(f['coords'][:])
    augmented_tokens = torch.tensor(f['augmented_tokens'][:]).int()
print('ecloud', eclouds.size())
print('atomic number', atomic_number.size())
print('coords', coords.size())
print('augmented tokens', augmented_tokens.size())
h_points = encode_points(atomic_number, coords)
eclouds = eclouds.float()
h_ecloud = encode_eclouds(eclouds)
print(f'hidden molecular points: {h_points.size()}')
print(f'hidden molecular eclouds: {h_ecloud.size()}')


ecloud torch.Size([2, 32, 32, 32])
atomic number torch.Size([2, 9])
coords torch.Size([2, 9, 3])
augmented tokens torch.Size([2, 128])
hidden molecular points: torch.Size([2, 256])
hidden molecular eclouds: torch.Size([2, 256])


In [49]:
#e3gnn_eclouds_clip_e2e
import random
from typing import Dict, List, Any

import numpy as np
import torch
import torch.nn as nn
from torch import autocast
from torch.nn import functional as F

from coati.containers.rdkit_utils import disable_logger, permute_smiles
from coati.models.encoding.e3gnn_clip import e3gnn_clip
from coati.models.encoding.fill_in_middle import adj_mat_to_tokens
from coati.models.encoding.smiles_xformer import (
    RotarySmilesTransformer,
    SmilesTransformerConfig,
)
from coati.models.encoding.tokenizers.trie_tokenizer import TrieTokenizer
from coati.models.encoding.prefix_encoder import Conv3DEncoder
from rdkit import Chem
from coati.models.encoding.clip_e2e import clip_loss
class e3gnn_eclouds_clip_e2e(nn.Module):
    def __init__(
        self,
        n_layer_e3gnn: int = 4,
        n_layer_xformer: int = 16,
        n_hidden_xformer: int = 128,
        n_hidden_e3nn: int = 128,
        msg_cutoff_e3nn: float = 4.0,
        n_embd_common: int = 128,
        n_head: int = 8,
        n_seq: int = 200,
        n_tok: int = 4,
        biases: bool = True,
        torch_emb: bool = False,
        residual: bool = False,
        norm_clips: bool = True,
        norm_embed: bool = False,
        token_mlp: bool = True,  # Do we use a nonlinear MLP to convert HCLIP into a token.
        use_point_encoder: bool = True,  # if false, do not use a point encoder at all.
        old_architecture: bool = False,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float,
    ):
        super().__init__()
        self.embed_dim = n_embd_common
        self.point_encoder = e3gnn_clip(
            device=device,
            dtype=dtype,
            hidden_nf=n_hidden_e3nn,
            message_cutoff=msg_cutoff_e3nn,
            dropout=0.0,
            torch_emb=torch_emb,
            residual=residual,
            n_layers=n_layer_e3gnn,
        )
        kwargs = {
            "n_layer": n_layer_xformer,
            "n_embd": n_hidden_xformer,
            "n_head": n_head,
            "n_seq": n_seq,
            "n_tok": n_tok,
            "device": device,
            "dtype": dtype,
            "biases": biases,
            "norm_embed": norm_embed,
        }
        self.xformer_config = SmilesTransformerConfig(**kwargs)
        self.xformer = RotarySmilesTransformer(self.xformer_config)
        self.device = device
        self.use_point_encoder = use_point_encoder
        # Each of these get a linear mapping into the common hidden space.
        self.ecloud_encoder = Conv3DEncoder(d_model=768)
        if norm_clips:
            if old_architecture:
                self.point_to_clip = nn.Sequential(
                    nn.Linear(self.point_encoder.hidden_nf, self.embed_dim),
                    nn.LayerNorm(self.point_encoder.hidden_nf),
                )
                self.smiles_to_clip = nn.Sequential(
                    nn.Linear(self.xformer.n_embd, self.embed_dim),
                    nn.LayerNorm(self.embed_dim),
                )
            else:
                self.point_to_clip = nn.Sequential(
                    nn.LayerNorm(self.point_encoder.hidden_nf),
                    nn.Linear(self.point_encoder.hidden_nf, self.embed_dim),
                )
                self.smiles_to_clip = nn.Sequential(
                    nn.LayerNorm(self.embed_dim),
                    nn.Linear(self.xformer.n_embd, self.embed_dim),
                )
                self.eclouds_to_clip = nn.Sequential(
                    nn.LayerNorm(64 * 768),
                    nn.Linear(64 * 768, self.embed_dim),
                )
        else:
            self.point_to_clip = nn.Linear(self.point_encoder.hidden_nf, self.embed_dim)
            self.smiles_to_clip = nn.Linear(self.xformer.n_embd, self.embed_dim)

        if token_mlp:
            # A mapping to make the special token(s?).
            self.point_clip_to_special_tokens = nn.Sequential(
                nn.SiLU(), nn.Linear(self.embed_dim, self.embed_dim)
            )
        else:
            self.point_clip_to_special_tokens = nn.Identity()

        n_params_e3gnn = sum(p.numel() for p in self.point_encoder.parameters())
        n_params_smiles = sum(p.numel() for p in self.xformer.parameters())
        n_params = n_params_e3gnn + n_params_smiles
        print(
            f"number of parameters Total: {n_params_e3gnn/1e6:.2f}M xformer: {n_params_smiles/1e6:.2f}M Total: {n_params/1e6:.2f}M "
        )
        self.clip_loss = clip_loss()
        self.to(self.device)
    
    
    
    def forward_dist(
        self,
        eclouds: torch.Tensor, 
        augmented_tokens: torch.Tensor,
        atoms: torch.Tensor,
        coords: torch.Tensor,
        tokenizer,
        p_clip_emb_smi: float = 0.4,
    ):
        """
        Same as the below routine but for DistributedDataParallel training.
        """
        with autocast(enabled=False, device_type="cuda"):
            h_e3gnn = encode_points(atoms, coords)
            h_eclouds = encode_eclouds(eclouds)
            try:
                assert h_e3gnn.shape[0] == h_eclouds.shape[0]
            except Exception as Ex:
                print(
                    Ex,
                    augmented_tokens.shape,
                    atoms.shape,
                    coords.shape,
                    h_e3gnn.shape,
                    h_eclouds.shape,
                )
                raise Ex
            # print('h_e3gnn', h_e3gnn.size())
            # print('h_ecloud', h_eclouds.size())
            point_clip_token = self.point_clip_to_special_tokens(h_e3gnn)
            eclouds_clip_token = self.point_clip_to_special_tokens(h_eclouds)
            clip_token = torch.where(
                (torch.rand((h_e3gnn.shape[0],), device=atoms.device) > p_clip_emb_smi)
                .unsqueeze(-1)
                .repeat(1, point_clip_token.shape[-1]),
                point_clip_token,
                eclouds_clip_token,
            )
        logits = self.xformer.forward_with_replacement(     #重新设计吧应该要
            augmented_tokens, clip_token, tokenizer
        )
        bad_rows = augmented_tokens.sum(-1) < 1
        return h_e3gnn, h_eclouds, logits, bad_rows

In [35]:
#args
import argparse
from coati.common.util import makedir, utc_epoch_now
def do_args():
    parser = argparse.ArgumentParser(description="token_transformer")
    parser.add_argument("--exp_name", type=str, default="token_transformer")
    parser.add_argument("--run_name", type=str, default=str(int(utc_epoch_now())))
    parser.add_argument("--output_dir", type=str, default="COATI_outputs")
    parser.add_argument("--model_dir", type=str, default="COATI_models")
    parser.add_argument("--data_dir", type=str, default="COATI_data")

    # ddp options.
    parser.add_argument(
        "-ws", "--world_size", default=1, type=int, help="total number of processes"
    )
    parser.add_argument(
        "-nr", "--nr", default=0, type=int, help="ranking within the nodes"
    )
    parser.add_argument(
        "-n", "--nodes", default=1, type=int, metavar="N", help="number of nodes"
    )
    parser.add_argument(
        "-g",
        "--gpus",
        default=torch.cuda.device_count(),
        type=int,
        help="number of gpus per node",
    )

    parser.add_argument(
        "--device", type=str, default="cuda", help="pytorch backend device."
    )
    parser.add_argument("--dtype", type=str, default="float", help="default data type")
    parser.add_argument("--log_batch_loss", default=25, help="steps per tnet log")
    parser.add_argument(
        "--code_features",
        default=["protein", "secondary", "library"],
        help="one hot encoded additional dimensions.",
    )
    parser.add_argument("--n_epochs", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=32)

    parser.add_argument(
        "--recipe",
        type=list,
        default=[
            {"collection": "geom_drugs", "n_samples": 6_000_000, "filter": {}},
        ],
    )

    parser.add_argument("--n_layer_e3gnn", type=int, default=4)
    parser.add_argument("--n_hidden_e3nn", type=int, default=128)
    parser.add_argument("--msg_cutoff_e3nn", type=float, default=10.0)
    parser.add_argument("--n_hidden_xformer", type=int, default=128)
    parser.add_argument("--n_embd_common", type=int, default=128)
    parser.add_argument("--n_layer_xformer", type=int, default=16)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument(
        "--biases", type=bool, default=True, help="Use biases in the xformer."
    )
    parser.add_argument("--n_seq", type=int, default=200)
    parser.add_argument("--tokenizer_vocab", type=str, default="Jan8")
    parser.add_argument("--torch_emb", type=bool, default=False)
    parser.add_argument(
        "--load_transformer_only",
        type=bool,
        default=False,
        help="load trained transformer but use fresh point encoder",
    )

    parser.add_argument("--p_dataset", type=float, default=0.3)
    parser.add_argument("--p_formula", type=float, default=0.3)
    parser.add_argument("--p_fim", type=float, default=0.5)
    parser.add_argument("--p_graph", type=float, default=0.3)
    parser.add_argument("--p_clip", type=float, default=0.3)
    parser.add_argument("--p_clip_cut", type=float, default=0.3)

    parser.add_argument("--p_clip_emb_smi", type=float, default=0.4)
    parser.add_argument("--p_randsmiles", type=float, default=0.5)

    parser.add_argument(
        "--norm_clips", type=bool, default=False, help="normalize the clip vectors"
    )
    parser.add_argument(
        "--token_mlp",
        type=bool,
        default=False,
        help="Do we use an MLP or just hclip as a token.",
    )
    parser.add_argument(
        "--norm_embed", type=bool, default=False, help="Layernorm after embedding"
    )
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--lr", type=float, default=4e-4)
    parser.add_argument("--clip_grad", type=float, default=10.0)

    parser.add_argument(
        "--do_clip",
        type=bool,
        default=True,
        help="If false, do not use clip loss during training.",
    )

    parser.add_argument(
        "--test_frac", type=float, default=0.02, help="test data fraction"
    )
    parser.add_argument(
        "--valid_frac", type=float, default=0.02, help="test data fraction"
    )
    parser.add_argument(
        "--test_interval",
        type=int,
        default=1,
        metavar="N",
        help="how many epochs to wait before logging test",
    )
    parser.add_argument(
        "--log_interval",
        type=int,
        default=100,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--ngrad_to_save", default=2e6, help="ngrad updates between model saves."
    )

    parser.add_argument(
        "--resume_document", default=None, help="Restore from an S3 document"
    )
    parser.add_argument(
        "--resume_optimizer",
        type=bool,
        default=False,
        help="Restore opt. from an S3 document",
    )

    args, unparsed_args = parser.parse_known_args()
    # args.cuda = not args.no_cuda and torch.cuda.is_available()
    if len(unparsed_args):
        print("Warning... unparsed: ", unparsed_args)
    return args

In [50]:
#forward
args = do_args()
args.nodes = 1  # total num nodes.
args.nr = 0  # rank of this node.
# note args.gpus will default to the # gpus on this node.
args.data_parallel = True

args.test_frac = 0.02
args.valid_frac = 0.0
args.n_layer_e3gnn = 5
args.n_hidden_e3nn = 256
args.msg_cutoff_e3nn = 12.0
args.n_hidden_xformer = 256
args.n_embd_common = 256
args.n_layer_xformer = 16
args.n_head = 16
args.max_n_seq = 250  # max the model can forward
#    args.n_seq = 90 # max allowed in training.
args.n_seq = 80  # max allowed in training.
args.biases = True
args.torch_emb = False
args.norm_clips = True
args.norm_embed = False
args.token_mlp = True

args.tokenizer_vocab = "mar"
args.p_dataset = 0.2
args.p_formula = 0.0
args.p_fim = 0.0
args.p_graph = 0.0
args.p_clip = 0.9
args.p_clip_emb_smi = 0.5
args.p_randsmiles = 0.3
args.batch_size = 160

args.online = False  # Possible offline training of an end-to-end clip
args.lr = 5.0e-4
args.weight_decay = 0.1

args.dtype = "float"
args.n_epochs = 25
args.clip_grad = 10
args.test_interval = 2
args.debug = False

args.resume_optimizer = False
# resume from checkpoint file
# args.resume_document = ''

args.ngrad_to_save = 2e6

# output logs
args.output_dir = "./logs/"
# where to save model checkpoints
args.model_dir = "./model_ckpts/"
# where to save dataset cache
args.data_dir = "./"
args.model_filename = "coati_grande"
from coati.models.encoding.tokenizers import get_vocab
from coati.models.encoding.clip_e2e import clip_loss as clip_loss_module
# tokenizer = TrieTokenizer(n_seq=args.n_seq, **get_vocab(args.tokenizer_vocab)) 'mar'
tokenizer = TrieTokenizer(n_seq=args.n_seq, **get_vocab('mar'))
token_entropy_unit = np.log(float(len(tokenizer.keys))) / np.log(2.0)
kwargs = {
    "n_layer_xformer": args.n_layer_xformer,
    "n_layer_e3gnn": args.n_layer_e3gnn,
    "n_hidden_e3nn": args.n_hidden_e3nn,
    "n_hidden_xformer": args.n_hidden_xformer,
    "n_embd_common": args.n_embd_common,
    "biases": args.biases,
    "n_head": args.n_head,
    "n_seq": args.max_n_seq,
    "n_tok": tokenizer.n_token,  # base
    "torch_emb": args.torch_emb,
    "norm_clips": args.norm_clips,
    "norm_embed": args.norm_embed,
    "token_mlp": args.token_mlp,
    }
model = e3gnn_eclouds_clip_e2e(**kwargs)
clip_computer = clip_loss_module()
h_e3gnn, h_xformer, logits, bad_rows = model.forward_dist(
    eclouds.to(device),
    augmented_tokens.to(device),
    atomic_number.to(device),
    coords.to(device),
    tokenizer,
    p_clip_emb_smi=args.p_clip_emb_smi,
)
print('h_e3gnn', h_e3gnn.size())
print('h_xformer', h_xformer.size())
print('logits', logits.size())
print('bad_rows', bad_rows.size())

#loss
bad_rows = bad_rows
all_h_xformer = h_xformer
all_h_e3gnn = h_e3gnn
y_next = torch.zeros_like(augmented_tokens).to(device)
y_next[:, :(augmented_tokens.shape[1] - 1)] = \
    torch.Tensor(augmented_tokens).clone()[:, 1:]
y_next[y_next == tokenizer.clip_token] = -1
y_next[y_next == tokenizer.pad_token] = -1
y_next[y_next == tokenizer.smiles_token] = -1
y_next[y_next == tokenizer.unk_token] = -1
y_next[y_next == tokenizer.suffix_token] = -1
y_next[y_next == tokenizer.middle_token] = -1

ar_loss_ = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                y_next.view(-1).long(),
                ignore_index=-1,
            )
ar_loss = ar_loss_.mean()

if args.do_clip:
    clip_loss_ = clip_computer(all_h_xformer, all_h_e3gnn, bad_rows) 
    # clip_loss_ = clip_computer(h_xformer, h_e3gnn, bad_rows) 
    clip_loss = clip_loss_.mean()
    loss = ar_loss + clip_loss * token_entropy_unit
else:
    loss = ar_loss

print('loss', loss)

number of parameters: 12.64M
number of parameters Total: 2.44M xformer: 19.60M Total: 22.04M 
h_e3gnn torch.Size([2, 256])
h_xformer torch.Size([2, 256])
logits torch.Size([2, 128, 13603])
bad_rows torch.Size([2])
