In [13]:
import os
import numpy as np
import pandas as pd
from mendeleev import element
from collections import Counter
import matplotlib

#matplotlib.use('Agg')
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import r2_score
from sklearn.decomposition import PCA

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.functional import mse_loss, l1_loss
import math
import sys


from msa_dist import build_model
from utils import parse_args
sys.path.append("..")
#from expand import neighbor_cells
import plotly.offline as pyoff
import plotly.graph_objs as go
import plotly.io as pio
%matplotlib inline
import pymatgen

In [14]:
import argparse
def parse_args():
    """Build the model/training argument namespace populated with defaults.

    The parser is always fed an empty argument list (``args=[]``), so the
    notebook kernel's own command line is never parsed and every option
    takes its declared default.

    Returns
    -------
    argparse.Namespace
        Namespace holding one attribute per option declared below.
    """
    # The parser object was never created before, so every call raised NameError.
    parser = argparse.ArgumentParser()

    # Model structure parameters
    parser.add_argument('--pretrain', type=str, default='', help='Whether to load the pretrained model weights.')
    parser.add_argument('--atom_class', type=int, default=100, help='The default number of atom classes + 1.')
    parser.add_argument('--n_encoder', type=int, default=1, help='Number of stacked encoder.')
    # argparse %-formats help strings, so a literal percent sign must be written as '%%'.
    parser.add_argument('--embed_dim', type=int, default=512, help='Dimension of PE, embed_dim %% head == 0.')
    parser.add_argument('--head', type=int, default=1, help='Number of heads in multi-head attention.')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate.')
    parser.add_argument('--max_len', type=int, default=1024, help='Maximum length for the positional embedding layer.')

    # Training procedure parameters
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--split_ratio', type=float, default=0.9)
    parser.add_argument('--epochs', type=int, default=300, help='Number of epoch.')
    parser.add_argument('--bs', type=int, default=32, help='Batch size.')
    parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate.')

    # Training environment configuration
    parser.add_argument('--gpu_id', type=str, default='1', help='Index for GPU')
    parser.add_argument('--save_path', default='save/', help='Path to save the model and the logger.')

    return parser.parse_args(args=[])

In [16]:
def trace_ge(sample,idx,i,linewidth,col):
    """Create the 3D line trace that joins the i-th pair of points listed in ``idx``.

    ``idx[0][i]`` and ``idx[1][i]`` are the row indices of the two end points
    in ``sample``; columns 0, 1 and 2 of those rows are used as x, y and z.
    ``linewidth`` and ``col`` set the width and colour of the line.
    """
    src = idx[0][i]
    dst = idx[1][i]
    # One [start, end] coordinate pair per spatial axis.
    xs, ys, zs = ([sample[src, axis], sample[dst, axis]] for axis in range(3))
    line_style = dict(width=linewidth, color=col)
    marker_style = dict(size=30, color=col, opacity=1)
    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        line=line_style,
        mode='lines',
        marker=marker_style,
    )

In [17]:
def plotplot(sample,idx,size1,COSC,head,name,linewidth,col):
    """Draw atoms and attention links as a 3D plotly figure and export it.

    Parameters
    ----------
    sample : torch.Tensor
        2-D tensor whose columns 0-2 are read as x, y, z. Row 0 is skipped
        when drawing the atom markers.
    idx : tuple
        Pair of equal-length index sequences; entry ``i`` links row
        ``idx[0][i]`` to row ``idx[1][i]`` of ``sample``. Links touching
        row 0 are not drawn.
    size1, COSC :
        Marker sizes and marker colours for the rows ``sample[1:]``.
    head :
        Label embedded in the output file names.
    name : str
        Prefix of the output file names and the directory for the HTML file.
    linewidth, col :
        Width and colour of the link lines; ``col`` is also part of the
        file names.

    Returns
    -------
    plotly.graph_objs.Figure
        The figure that was written to disk.

    Notes
    -----
    The two export calls originally started with a stray quote and did not
    parse. The paths below are rebuilt from the remaining tokens: the PNG goes
    to the working directory and the HTML into a directory called ``name``.
    NOTE(review): confirm this is the intended output location.
    """
    atoms = go.Scatter3d(
        x=sample[1:,0].numpy(),
        y=sample[1:,1].numpy(),
        z=sample[1:,2].numpy(),
        mode='markers',
        marker=dict(size=size1, color=COSC, opacity=0.8),
    )
    data = [atoms]
    for i in range(len(idx[0])):
        # Index 0 is excluded from the markers above, so skip links touching it.
        if idx[0][i] != 0 and idx[1][i] != 0:
            data.append(trace_ge(sample,idx,i,linewidth,col))

    # All three axes are hidden in the same way.
    hidden_axis = dict(title='',ticks='',linecolor='white',showticklabels=False)
    layout = go.Layout(
        template='simple_white',
        scene=dict(
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(hidden_axis),
            camera=dict(
                eye=dict(x=0,y=0,z=2),
                center=dict(x=0,y=0,z=0),
                # NOTE(review): an all-zero "up" vector is kept from the original;
                # confirm it renders with the intended orientation.
                up=dict(x=0,y=0,z=0),
            ),
        ),
        width=900,
        height=900*0.618,
    )

    fig = go.Figure(data=data, layout=layout)

    stem = name+str(col)+'_'+str(head)+'head'+'_20220412_atten'
    png_path = stem+'.png'
    html_path = os.path.join(name, stem+'.html')
    # The HTML export targets a sub-directory; make sure it exists first.
    os.makedirs(os.path.dirname(html_path), exist_ok=True)
    pio.write_image(fig, png_path, width=1000, height=1000, format='png')
    pio.write_html(fig, html_path, default_width='3000px', default_height='3000px')
    return fig

In [18]:
# List the example CIF files. The path is assembled from components because
# '\3' inside a normal string literal is an octal escape (chr(3)), which
# silently corrupted the directory name in the previous literal.
MOF_LIST = os.listdir(os.path.join('..', '3D attention visualization', 'Example_cif'))

In [19]:
# Restrict the run to a single example structure; this replaces any
# previously assigned MOF_LIST.
MOF_LIST = ["UTSA-74.cif"]

In [9]:
# Create one output directory per structure. exist_ok=True keeps this cell
# re-runnable: without it a second execution raised FileExistsError.
for mof_name in MOF_LIST:
    os.makedirs(mof_name + '20220623', exist_ok=True)

In [23]:
#plt.style.use(['tableau-colorblind10'])
import plotly.offline as pyoff
import plotly.graph_objs as go
import plotly.io as pio
#k=1
def get_att(mof_name, cif_dir=None, cpt_path='Coremof_model.pt'):
    """Run the pretrained model on one CIF file and collect its attention maps.

    Parameters
    ----------
    mof_name : str
        File name of the structure inside ``cif_dir``.
    cif_dir : str, optional
        Directory holding the CIF files. Defaults to
        ``../3D attention visualization/Example_cif``.
    cpt_path : str, optional
        Checkpoint file providing the ``'dist_bar'`` and ``'model'`` entries.

    Returns
    -------
    tuple
        ``(head, sample)``. ``head[h][layer]`` is the attention tensor taken
        from ``model.encoder.layers[layer].self_attn.attn[0, h]`` (moved to
        CPU) for 4 heads and 6 layers. ``sample`` is the model input: row 0
        holds the lattice lengths plus the extra class id ``atom_class``, and
        every following row holds ``x, y, z, atomic_number`` for one site.
    """
    # Imported from pymatgen.core: the root-level ``pymatgen.Structure`` alias
    # is not available in current pymatgen releases.
    from pymatgen.core import Structure

    atom_class = 100
    n_layers = 6
    n_heads = 4

    if cif_dir is None:
        # Assembled from components: '\3' in the former literal was an octal
        # escape (chr(3)) and corrupted the directory name.
        cif_dir = os.path.join('..', '3D attention visualization', 'Example_cif')
    structure = Structure.from_file(os.path.join(cif_dir, mof_name))

    # Enlarge small cells: 38-128 sites -> 2x supercell, <=37 sites -> 3x.
    n_sites = len(structure)
    if 38 <= n_sites <= 128:
        structure.make_supercell(2)
    elif n_sites <= 37:
        structure.make_supercell(3)

    numbers = torch.Tensor(structure.atomic_numbers).reshape(-1, 1)
    coords = torch.Tensor(structure.cart_coords)
    atoms = torch.cat((coords, numbers), dim=1)
    # NOTE(review): the class-id entry is created as float64, exactly as before;
    # confirm that is the dtype the model expects for the coordinates.
    lattice_row = torch.cat(
        (torch.Tensor(structure.lattice.abc), torch.ones(1).double() * atom_class), dim=0
    )
    sample = torch.cat((lattice_row.unsqueeze(0), atoms), dim=0)

    # Fall back to the CPU when no GPU is present instead of failing in .cuda().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    cpt = torch.load(cpt_path, map_location=device)
    model = build_model(
        atom_class + 1, tgt=7, dist_bar=cpt['dist_bar'], N=n_layers, head=n_heads, dropout=0
    ).to(device)
    model.load_state_dict(cpt['model'])
    model.eval()
    with torch.no_grad():
        x = sample.unsqueeze(0).to(device)
        mask = (x[..., 3] != 0).unsqueeze(1)
        # The forward pass is run only so that each layer stores its attention.
        _ = model(x[..., 3].long(), mask, x[..., :3])

    layers = model.encoder.layers
    head = [
        [layers[layer].self_attn.attn[0, h].detach().cpu() for layer in range(n_layers)]
        for h in range(n_heads)
    ]
    return head, sample

In [26]:
def _atom_style(atomic_number, base):
    """Return the ``(rgb colour, marker size)`` pair used to draw one atom."""
    if atomic_number == 1:  # H
        return (200/255, 200/255, 200/255), 0.5*base
    if atomic_number == 6:  # C
        return (127/255, 127/255, 127/255), 1.2*base
    if atomic_number == 7:  # N
        return (0/255, 176/255, 240/255), 1.2*base
    if atomic_number == 8:  # O
        return (223/255, 15/255, 90/255), 1.5*base
    if atomic_number == 9:  # F
        return (224/255, 17/255, 89/255), 2*base
    if 11 <= atomic_number <= 17:  # Na .. Cl
        return (98/255, 193/255, 189/255), 2.5*base
    if 29 <= atomic_number <= 86:  # Cu .. Rn
        return (38/255, 175/255, 103/255), 2.5*base
    # Fallback for every other element. Without it such atoms received no
    # colour/size entry and all later atoms were drawn with the wrong style.
    return (148/255, 103/255, 189/255), 2*base


def plot_final(head,sample,mof_name):
    """Plot, for every attention head, the strongest attention links of all layers.

    Parameters
    ----------
    head : sequence
        ``head[h][layer]`` is a square attention tensor (as returned by
        ``get_att``).
    sample : torch.Tensor
        Model input; column 3 holds the atomic numbers and row 0 is skipped
        when styling the atoms.
    mof_name : str
        Name forwarded to ``plotplot`` for the output files.

    For each layer the links whose attention exceeds the 20th largest value of
    ``att[1:, 1:]`` are selected. The selections of all layers of one head are
    concatenated and drawn in a single figure. A failing plot is reported and
    does not stop the remaining heads.
    """
    j = 'blue'
    linewidth = 1.5
    a = 15

    # Marker styling depends only on the atoms, so compute it once for all heads.
    styles = [_atom_style(int(z), a) for z in sample[1:, 3].long().tolist()]
    COSC = [colour for colour, _ in styles]
    size1 = [size for _, size in styles]

    for head_num, layers in enumerate(head):
        idx = None
        idx_all = None
        for att in layers:
            inner = att[1:, 1:].reshape(-1)
            if inner.numel() == 0:
                continue
            # 20th largest value, or the smallest one when fewer than 20 exist.
            k = min(20, inner.numel())
            threshold = inner.sort()[0][-k]
            idx = (att > threshold).nonzero(as_tuple=True)
            if idx_all is None:
                idx_all = idx
            else:
                idx_all = (torch.cat((idx_all[0], idx[0])), torch.cat((idx_all[1], idx[1])))

        if idx_all is None:
            print(str(j)+' failed: no attention values to plot')
            continue

        try:
            plotplot(sample, idx_all, size1, COSC, head_num+1, mof_name, linewidth, j)
        except Exception as exc:
            # Best effort: report the reason and carry on with the next head.
            print(str(j)+' failed: '+repr(exc))
        print(j)
        # Number of selected links of the last layer whose source index is not 0.
        print((idx[0] != 0).sum().numpy())

In [None]:
# Process every structure. A failure is reported with its cause and the loop
# moves on to the next file. Catching Exception (instead of a bare except)
# lets Ctrl-C / kernel interrupts stop the run.
for mof_name in MOF_LIST:
    try:
        heads, sample = get_att(mof_name)
        plot_final(heads, sample, mof_name)
    except Exception as exc:
        print(mof_name + ' failed: ' + repr(exc))
