In [1]:
import torch 
import torch.nn as nn
from torch_scatter import scatter


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Positional_embedding(nn.Module):
    '''Placeholder for a positional-embedding module.

    No layers and no ``forward`` are defined yet; the class only reserves the
    name and inherits everything from ``nn.Module``.
    '''

In [133]:
class Gemformer(nn.Module):
    '''Graph-Transformer readout head applied to the GemNet-OC latent space
    (obtained with the best pre-trained parameters).

    The latent features are linearly projected to query/key/value tensors,
    passed through a single multi-head attention layer, summed over the
    leading dimension, layer-normalised, sum-pooled per molecule and finally
    mapped to the targets by a two-layer MLP.

    Parameters
    ----------
    num_heads : int
        Number of attention heads.
    emb_size_in : int
        Size of the last dimension of ``data.latent``.
    emb_size_trans : int
        Embedding size used inside the attention layer, the layer norm and as
        the input width of the output MLP.
    out_layer1 : int, default 32
        Hidden width of the output MLP.
    out_layer2 : int, default 1
        Output width of the output MLP (number of values predicted per molecule).
    '''

    def __init__(self,num_heads,emb_size_in,emb_size_trans,out_layer1=32,out_layer2=1):
        super().__init__()
        self.num_heads=num_heads

        self.out_layer1=out_layer1
        self.out_layer2=out_layer2
        # Output MLP: emb_size_trans -> out_layer1 -> out_layer2.
        self.dense=nn.Sequential(nn.Linear(emb_size_trans,out_layer1),
                                 nn.SiLU(),
                                 nn.Linear(out_layer1,out_layer2)
                                 )

        # Separate input projections for the attention query, key and value.
        self.lin_query_MHA=nn.Linear(emb_size_in,emb_size_trans)
        self.lin_key_MHA=nn.Linear(emb_size_in,emb_size_trans)
        self.lin_value_MHA=nn.Linear(emb_size_in,emb_size_trans)

        # Not used by forward(); kept so the attribute stays available to
        # any code that accesses it.
        self.softmax=nn.Softmax(dim=1)

        self.MHA=nn.MultiheadAttention(embed_dim=emb_size_trans,
                                       num_heads=num_heads,
                                       bias=True,
                                       dropout=0.0,
                                       )
        self.layer_norm = nn.LayerNorm(emb_size_trans)

    def check_shape(self,va):
        '''Print the shape, mean and std of ``va``.

        Debugging helper used only while developing the model. The printed
        label is the class name of ``va``.
        '''
        print(f'The {va.__class__.__name__} shape is:',va.shape)
        print(f'The {va.__class__.__name__} mean and std of is:',va.mean(),va.std())

    def forward(self,data):
        '''Compute one prediction vector per molecule.

        Parameters
        ----------
        data :
            Object exposing ``latent`` (the input features) and ``batch``
            (integer tensor giving the molecule index of every row that is
            pooled).
            NOTE(review): the code below only works end-to-end when
            ``latent`` is 3-D, presumably (n_blocks, n_atoms, emb_size_in)
            -- TODO confirm against the code that produces ``latent``.

        Returns
        -------
        torch.Tensor
            Tensor of shape (nMolecules, out_layer2), where
            ``nMolecules = max(batch) + 1``.
        '''
        E_all=data.latent
        batch=data.batch

        q=self.lin_query_MHA(E_all)
        k=self.lin_key_MHA(E_all)
        v=self.lin_value_MHA(E_all)

        # The attention map was never used, so do not ask for it.
        E_t,_=self.MHA(q,k,v,need_weights=False)
        # Collapse the leading (sequence) dimension of the attention output.
        E_t=torch.sum(E_t,dim=0)
        E_t=self.layer_norm(E_t)

        # torch_scatter expects `dim_size` to be a plain int, not a 0-dim tensor.
        nMolecules=int(batch.max())+1
        E_t=scatter(
                E_t, batch, dim=0, dim_size=nMolecules, reduce="add"
            )  # (nMolecules, emb_size_trans)

        E_t=self.dense(E_t)  # (nMolecules, out_layer2)

        return E_t
    

In [134]:
# Fall back to the CPU when no CUDA device is available so the notebook also
# runs on CPU-only machines.
DEVICE='cuda' if torch.cuda.is_available() else 'cpu'
def out_fn(dataloader,model):
    '''Run ``model`` on the first batch of ``dataloader`` and return its output.

    The model is switched to eval mode and moved to ``DEVICE`` (both changes
    persist after the call); the forward pass runs under ``torch.no_grad()``.
    Only the first batch is consumed.

    Parameters
    ----------
    dataloader :
        Iterable yielding batches that provide a ``.to(device)`` method.
    model : nn.Module
        Model to evaluate.

    Returns
    -------
    The value returned by ``model(first_batch)``.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no batches.
    '''
    model.eval()
    # Move the model once, not on every loop iteration.
    model=model.to(DEVICE)

    first_batch=next(iter(dataloader),None)
    if first_batch is None:
        raise ValueError('dataloader yielded no batches')

    with torch.no_grad():
        output=model(first_batch.to(DEVICE))

    return output

In [135]:
# Single-head model: 256-wide input features projected to a 64-wide attention embedding.
myTransformer = Gemformer(num_heads=1, emb_size_in=256, emb_size_trans=64)

In [136]:
from ocpmodels.datasets import LmdbDataset
import torch_geometric.loader as geom_loader

# Open the LMDB dataset, wrap it in a loader that yields one item per batch,
# and run the model on the first batch only (see out_fn).
dataset = LmdbDataset({"src": "Data/eoh_t.lmdb"})
node_data_loader = geom_loader.DataLoader(dataset, batch_size=1)
output = out_fn(node_data_loader, myTransformer)

The Tensor shape is: torch.Size([1, 64])
The Tensor mean and std of is: tensor(-1.0133e-06, device='cuda:0') tensor(54.8691, device='cuda:0')
The Tensor shape is: torch.Size([1, 1])
The Tensor mean and std of is: tensor(0.6330, device='cuda:0') tensor(nan, device='cuda:0')


In [137]:
# Display the prediction's shape together with its values.
(output.shape, output)

(torch.Size([1, 1]), tensor([[0.6330]], device='cuda:0'))

In [138]:
# scaling_file=torch.load('params/gemnet_oc_base_oc20_oc22_scales.pt')
# scaling_file