In [None]:
import torch
import os
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from models.regression_train_test import train, test, predictingSingle
from models.Predictor_resgatedgraphconvN import MolRGNNPredictorPredictor
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  
os.environ["CUDA_VISIBLE_DEVICES"]="0"
from torch_geometric.utils import from_smiles
from torch_geometric.data import InMemoryDataset #, Data
from rdkit import Chem
from rdkit.Chem import Draw
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Configuration: random seed and compute device.
import time

seed = 120
# torch.manual_seed seeds the CPU generator as well as every CUDA device, so
# anything drawing from torch's RNG (e.g. DataLoader shuffling) is
# reproducible whether or not a GPU is present.
torch.manual_seed(seed)

# Inference in this notebook runs on CPU; GPU use is an explicit opt-in.
# CUDA_VISIBLE_DEVICES="0" exposes a single GPU, which torch addresses as
# "cuda:0" (an index such as "cuda:4" would not exist in this process).
USE_GPU = False
if USE_GPU and torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [3]:
class Molecule_data(InMemoryDataset):
    """In-memory PyG dataset of molecular graphs built from SMILES strings.

    Graphs are produced with ``from_smiles`` and cached as
    ``<root>/processed/<dataset>.pt``. If that file already exists it is
    loaded as-is and ``smiles`` is ignored.

    Args:
        root: Directory under which the processed cache is saved
            (default '/tmp').
        dataset: Base name of the cache file (``<dataset>.pt``).
        y: Unused; kept for backward compatibility.
        transform: Optional transform, forwarded to the base class.
        pre_transform: Optional transform applied to each graph before it is
            collated and saved.
        smiles: Sequence of SMILES strings; required when no cache file
            exists yet.

    Raises:
        ValueError: if no cache file exists and ``smiles`` is None or empty.
    """

    def __init__(self, root='/tmp', dataset='davis', y=None, transform=None,
                 pre_transform=None, smiles=None):
        # `dataset` is assigned before the base constructor runs because
        # processed_file_names (and hence processed_paths) is derived from it.
        self.dataset = dataset
        super(Molecule_data, self).__init__(root, transform, pre_transform)
        cache_path = self.processed_paths[0]
        if not os.path.isfile(cache_path):
            # No cache yet: SMILES are required to build it.
            if smiles is None or len(smiles) == 0:
                raise ValueError(
                    "Molecule_data: no cached file at %r and no SMILES given "
                    "to build it from" % cache_path)
            self.process(smiles)
        self.data, self.slices = torch.load(cache_path)

    @property
    def raw_file_names(self):
        # No raw files: graphs are built directly from SMILES in process().
        pass

    @property
    def processed_file_names(self):
        return [self.dataset + '.pt']

    def download(self):
        # Nothing to download.
        pass

    def _download(self):
        pass

    def _process(self):
        # Overrides the base-class hook. It presumably keeps the base
        # constructor from calling process(), which here needs `smiles`.
        # It only makes sure the processed directory exists.
        if not os.path.exists(self.processed_dir):
            os.makedirs(self.processed_dir)

    def process(self, smiles):
        """Convert SMILES strings to graphs, collate them and save to disk.

        Unlike the PyG base-class hook, this takes the SMILES explicitly and
        is invoked from ``__init__``.

        Args:
            smiles: Sequence of SMILES strings.
        """
        data_list = []
        for smile in smiles:
            graph = from_smiles(smile)
            # Cast node/edge features to float32 for downstream transforms
            # and the model.
            graph.x = graph.x.type(torch.FloatTensor)
            graph.edge_attr = graph.edge_attr.type(torch.FloatTensor)
            graph.smile_fingerprint = None
            data_list.append(graph)

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [4]:

def Predictions(smilesList, targetToPredicate, device='cpu'):
    """Predict one target property for a list of SMILES strings.

    Loads ``saved_models/<targetToPredicate>model.model`` into a
    ``MolRGNNPredictorPredictor`` and runs ``predictingSingle`` over the
    molecules.

    Args:
        smilesList: Sequence of SMILES strings.
        targetToPredicate: Target name used to select the checkpoint file
            (e.g. 'voc').
        device: Torch device to run on; defaults to 'cpu'.

    Returns:
        Whatever ``predictingSingle`` returns. In this notebook's output that
        is a float32 numpy array, presumably one value per molecule in the
        same order as ``smilesList``. An empty input returns an empty float32
        array without loading anything.
    """
    # Guard: building an empty graph dataset would fail in collate.
    if len(smilesList) == 0:
        return np.empty(0, dtype=np.float32)

    import tempfile

    transform = T.Compose([T.NormalizeFeatures(['x', 'edge_attr'])])

    model = MolRGNNPredictorPredictor().to(device)
    model_file_name = 'saved_models/' + targetToPredicate + 'model.model'
    checkpoint = torch.load(model_file_name, map_location=torch.device(device))
    model.load_state_dict(checkpoint)
    # Ensure inference mode (affects layers such as dropout/batch-norm, if
    # present) regardless of what predictingSingle does.
    model.eval()

    # The graph cache is only needed while the dataset is built and iterated,
    # so keep it in a throwaway directory instead of leaving one file per call
    # under data/processed.
    with tempfile.TemporaryDirectory() as cache_root:
        test_data = Molecule_data(root=cache_root, dataset='test_data_set',
                                  y=None, smiles=smilesList,
                                  transform=transform)
        # shuffle must stay False: the outputs have to line up with smilesList.
        noveltest_loader = DataLoader(test_data, batch_size=64, shuffle=False)
        prediction = predictingSingle(noveltest_loader, model, device)

    return prediction


In [5]:
def smilesToMolecularStructure(smiles):
    """Render each SMILES string as an RDKit molecule image.

    Args:
        smiles: Iterable of SMILES strings.

    Returns:
        list with one entry per input: the image from ``Draw.MolToImage``, or
        None when ``Chem.MolFromSmiles`` yields no molecule.
    """
    def _draw_or_none(smiles_string):
        molecule = Chem.MolFromSmiles(smiles_string)
        return Draw.MolToImage(molecule) if molecule else None

    return [_draw_or_none(s) for s in smiles]

In [6]:
def makePredictions(inputSmiles, targets=('voc', 'e_gap_alpha', 'jsc', 'pce')):
    """Run ``Predictions`` for each target and collect results by target.

    Args:
        inputSmiles: Sequence of SMILES strings, passed unchanged to
            ``Predictions``.
        targets: Target names; each selects a checkpoint in ``Predictions``.
            Defaults to the four targets this notebook was written for.

    Returns:
        dict mapping each target name to the value returned by
        ``Predictions``, in the order given by ``targets``.
    """
    return {target: Predictions(inputSmiles, target) for target in targets}

In [7]:
# Predict every target handled by makePredictions for one example SMILES.
smileslist = [
    'C12=C(C=C3C(=C1)C=CC=C3)C=C4C(=C2)C=CC=C4',
]
predictions_result = makePredictions(smileslist)
print(predictions_result)

{'voc': array([0.7608636], dtype=float32), 'e_gap_alpha': array([2.5942693], dtype=float32), 'jsc': array([26.251654], dtype=float32), 'pce': array([2.4244654], dtype=float32)}
