In [11]:
from torch import load
import torch
import sys
import numpy as np
from schnetpack.data import AtomsData
import schnetpack as spk
import ase.db
import ase.io
from ase.io import read,write
import joblib
import lightgbm as lgb
import os
from openbabel import openbabel
from rdkit import Chem

In [12]:
def load_model(target,path):
    if target!='PCE':
        model1=load(path, map_location=torch.device('cpu'))
        model1.eval()
        return model1
    else:
        model= joblib.load(path)
        
        return model
    

In [13]:
def cal_nd(mol):
    atoms=mol.toatoms()
    write('mol.xyz',atoms)
    obConversion = openbabel.OBConversion()
    obConversion.SetInAndOutFormats("xyz", "mol")
    mol = openbabel.OBMol()
    obConversion.ReadFile(mol, "mol.xyz")   # Open Babel will uncompress automatically
    mol.AddHydrogens() 
    obConversion.WriteFile(mol, '1.mol')

    #calculate Nd         
    mol = Chem.MolFromMolFile('1.mol')        
    n = len(mol.GetAtoms())         
    Nd = 0        
    for i in range(0,n):
        atom = mol.GetAtomWithIdx(i)
        #判断原子是否为芳香性
        if atom.GetIsAromatic() == True:
            Nd += 1
        if atom.GetIsAromatic() == False:
            #判断原子价电子是否等于总饱和度
            if atom.GetTotalValence() != atom.GetTotalDegree():
                Nd += 1
            if atom.GetTotalValence() == atom.GetTotalDegree():
                #判断原子是否在环上
                if atom.IsInRing() == True:
                    Nd += 1
        
        
    return Nd

In [14]:
def cal_prop(moln,molo,tag):
    
    
    al=.0
    if molo.data.Acceptor=='PC61BM':
        al= -3.70
        adl= 0.077824564
    if molo.data.Acceptor=='PC71BM':
        al= -3.91
        adl= 0.033470005
    if tag=='edahl':
        prop=al-float(moln.homo)
    if tag=='edall':
        prop=float(moln.lumo)-al
    if tag=='adlumo':
        prop=adl
    if tag=='nd':
        prop=cal_nd(moln)


    return prop

In [15]:
def pred_data( model,tag,data):
     
            
    if tag== 'PCE':
        return pred_pce(model,data)
        
    else :
         return pred_prop(model,tag,data)    
             

In [16]:
def pred_pce(model,data):
    db=ase.db.connect(data)
    pce=[]
    ids=[]
    for row in db.select():
        x=[]
        x.extend((row.homo,row.lumo,row.edahl,row.edall,row.et1,row.nd,row.adlumo,row.dhomo,row.dlumo))
        print(x)
        y = model.predict(np.array(x).reshape(1,-1)).tolist()
        print(y)
        pce.extend(y)
        ids.append(row.id)
        
    return ids,pce

In [17]:
def pred_prop(model,tag,data):
    pred=AtomsData(data)
    pred_loader = spk.AtomsLoader(pred, batch_size=10) #40!!
    for count, batch in enumerate(pred_loader):
        datapred = model(batch)
        ids=batch['_idx'].numpy().tolist()
        datapred=datapred[tag].detach().numpy().tolist()
        yield datapred,ids

In [18]:
def write_results(predata,tag,db):
    
    for num in predata.keys():
        for prop in predata[num].keys():
            
            db.update(id=num+1, **{prop: predata[num][prop]}) 
    
    return 0

In [19]:
def main():
    target=['et1','dhomo','dlumo','homo','lumo'] # need to predict with schnet
    target2=['nd','edahl','edall','adlumo'] # no need to predict
    predata={}
    db=ase.db.connect('predCSDSC.db')
    
    odb=ase.db.connect('CSDSC.db')
    for mol in odb.select():      
            atom=mol.toatoms()
            db.write(atom,name = mol.data.name)
    for tag in target:
        best_model=load_model(target=tag,path='./package/'+tag+'_model')
    
        for property,id in pred_data(best_model,tag,data='predCSDSC.db'):
            for sid,sprop in zip(id,property):    
                predata.update({sid[0]:{tag:sprop[0]}})
        write_results(predata,tag,db)    
    
    for tag in target2:
        for moln,molo in zip(db.select(),odb.select()):
            sprop=cal_prop(moln,molo,tag)
            sid=moln.id-1
            predata.update({sid:{tag:sprop}})
           
        write_results(predata,tag,db) 
    pcemodel=load_model(target='PCE',path='./package/lgb_model')
    
    ids,pce=pred_data(model=pcemodel,tag='PCE',data='predCSDSC.db')

    for sid,spce in zip(ids,pce):
        
        db.update(id=sid,PCE=spce)

    
    return 0

In [20]:
if __name__ == '__main__':
    status = main()
    
    



[-4.803012847900391, -2.5448944568634033, 0.8930128479003905, 1.3651055431365968, 1.266765832901001, 54, 0.033470005, 0.5702119469642639, 0.3076286315917969]
[5.377287420320842]
[-4.6082682609558105, -2.407496213912964, 0.6982682609558104, 1.5025037860870363, 1.369258165359497, 43, 0.033470005, 0.8138619661331177, 0.7633200287818909]
[4.434557877217224]
[-4.7147698402404785, -2.6450564861297607, 0.8047698402404784, 1.2649435138702394, 1.4089423418045044, 74, 0.033470005, 0.3619072735309601, 0.24833713471889496]
[4.678312754688934]
[-4.782886028289795, -2.423271894454956, 0.8728860282897948, 1.486728105545044, 1.4600920677185059, 47, 0.033470005, 0.697253406047821, 0.40263238549232483]
[5.000191014852801]
[-4.867783069610596, -2.4019241333007812, 0.9577830696105956, 1.508075866699219, 1.479724645614624, 52, 0.033470005, 0.5940201878547668, 0.19581235945224762]
[5.290166825045445]
[-4.803682327270508, -2.5825769901275635, 0.8936823272705077, 1.3274230098724367, 1.388946294784546, 52, 0.0

[4.79097549088464]
[-4.722800254821777, -2.27439284324646, 0.8128002548217772, 1.6356071567535402, 1.5196336507797241, 47, 0.033470005, 0.6479974389076233, 0.9320326447486877]
[4.1181291196182155]
[-4.773195266723633, -2.696199655532837, 0.8631952667236327, 1.2138003444671632, 1.4177546501159668, 78, 0.033470005, 0.34198126196861267, 0.3387451469898224]
[4.7990306106629905]
[-4.803099155426025, -2.5008628368377686, 0.8930991554260252, 1.4091371631622316, 1.5535019636154175, 51, 0.033470005, 0.5813171863555908, 0.45805367827415466]
[4.970578236754236]
[-4.8598313331604, -2.577986240386963, 0.9498313331604002, 1.3320137596130373, 1.5803375244140625, 56, 0.033470005, 0.5002591013908386, 0.25157374143600464]
[5.157293378065455]
[-4.766946315765381, -2.6386451721191406, 0.8569463157653807, 1.2713548278808595, 1.4350900650024414, 58, 0.033470005, 0.5294405221939087, 0.29432985186576843]
[5.273805144472121]
[-4.6294708251953125, -2.358267068862915, 0.7194708251953124, 1.551732931137085, 1.572