In [1]:
# Make the local `chempe` package importable by putting the directory two
# levels above the working directory on sys.path.
import sys
if '../..' not in sys.path:
    # Guard so re-running this cell does not keep appending duplicate entries.
    sys.path.append('../..')


In [2]:
import numpy as np
import os

# Step 1 Prediction

In [3]:
import rdkit
from rdkit import Chem
import tqdm
from chempe.data.warnings import disable_warnings, enable_warnings
from collections import Counter


In [4]:
#model = 'gpt-4'
model = 'gpt-4'

s1_pred_path = f'../../results/step1_prediction/{model}/step1_prediction_results.npy'
s1_results = np.load(s1_pred_path, allow_pickle=True).item()
s1_pred_sc_path = f'../../results/step1_prediction_sc/{model}/step1_prediction_sc_results.npy'
s1_sc_results = np.load(s1_pred_sc_path, allow_pickle=True).item()
s1_pred_zs_path = f'../../results/step1_prediction_zs/gpt-4/step1_prediction_zs_results.npy'
s1_zs_results = np.load(s1_pred_zs_path, allow_pickle=True).item()

In [5]:
# Loads indices
test_indices = np.load("../../data/test_indices.npy", allow_pickle=True)
# number of tests
n = 20

In [54]:
# evaluate results
def eval_baseline(rez, warnings=False):
    """Score single-sample predictions against the ground truth.

    Only the first prediction stored for each test reaction is scored. Both
    the prediction and the ground truth are canonicalised with RDKit before
    comparison, so differences in atom ordering do not count as errors.

    Args:
        rez: results dict with ``'ground_truth'`` and ``'predicted'`` entries
            indexable by the notebook-level ``test_indices``;
            ``rez['predicted'][k][0]`` is the SMILES that gets scored.
        warnings: if False (default), RDKit warnings are silenced while
            scoring; they are re-enabled afterwards in every case.

    Returns:
        A list of length ``n`` (notebook-level) with one tuple per test item:
        ``(1, pred, gt)`` for a match, ``(-1, pred, gt)`` for a parseable but
        wrong prediction, and ``(-2, None, gt)`` when the prediction could
        not be parsed.
    """
    def canon(smiles):
        # Chem.CanonSmiles raises on unparsable input; report that as None.
        try:
            return Chem.CanonSmiles(smiles)
        except Exception:
            return None

    if not warnings:
        disable_warnings()

    top1 = 0
    invalid = 0
    res_type = [0] * n
    try:
        for i in range(n):
            idx = test_indices[i]
            raw_gt = rez['ground_truth'][idx]
            # Canonicalise the ground truth like the other evaluators do; if
            # it cannot be parsed, keep the raw string instead of crashing.
            gt = canon(raw_gt)
            if gt is None:
                gt = raw_gt
            pred = canon(rez['predicted'][idx][0])
            if pred is None:
                invalid += 1
                res_type[i] = (-2, None, gt)
            elif pred == gt:
                top1 += 1
                res_type[i] = (1, pred, gt)
            else:
                res_type[i] = (-1, pred, gt)

        print(f"Num correct: {top1}/{n}")
        print(f"Num Invalid: {invalid}/{n}")
    finally:
        # Restore RDKit warnings even if a lookup above raised.
        enable_warnings()
    return res_type
    
def eval_sc(rez, warnings=False):
    """Score self-consistency predictions by majority vote over samples.

    For each test reaction, every sampled prediction is canonicalised with
    RDKit. Unparsable samples are dropped, and identical structures are
    counted. The most frequent structure is the top-1 answer; the five most
    frequent structures are the top-5 candidates.

    Args:
        rez: results dict with ``'ground_truth'`` and ``'predicted'`` entries
            indexable by the notebook-level ``test_indices``;
            ``rez['predicted'][k]`` is an iterable of sampled SMILES strings.
        warnings: if False (default), RDKit warnings are silenced while
            scoring; they are re-enabled afterwards in every case.

    Returns:
        A list of length ``n`` with one tuple per test item:
        ``(1, top1_smiles, gt)`` for a top-1 hit, ``(5, top5_smiles, gt)``
        for a hit that is only in the top 5, ``(-1, top1_smiles, gt)`` for a
        miss, and ``(-2, None, gt)`` when no sample could be parsed.

    It also prints the top-1/top-5/invalid tallies and, for every hit, the
    vote counts: a 1-tuple for top-1 hits, and up to five counts for top-5
    hits.
    """
    def canon(smiles):
        # Chem.CanonSmiles raises on unparsable input; report that as None.
        try:
            return Chem.CanonSmiles(smiles)
        except Exception:
            return None

    def count_preds(predicted):
        # Vote count over canonical structures; unparsable samples are skipped.
        return Counter(c for c in map(canon, predicted) if c is not None)

    if not warnings:
        disable_warnings()

    top1 = 0
    top5 = 0
    invalid = 0
    res_type = [-1] * n
    counts = []
    try:
        for i in range(n):
            idx = test_indices[i]
            raw_gt = rez['ground_truth'][idx]
            # Keep scoring instead of crashing on an unparsable ground truth.
            gt = canon(raw_gt)
            if gt is None:
                gt = raw_gt
            counter = count_preds(rez['predicted'][idx])
            if not counter:
                invalid += 1
                res_type[i] = (-2, None, gt)
                continue

            ranked = counter.most_common(5)
            best, best_count = ranked[0]
            if best == gt:
                counts.append((best_count,))
                top1 += 1
                top5 += 1
                res_type[i] = (1, best, gt)
                continue

            top_smiles, top_counts = zip(*ranked)
            if gt in top_smiles:
                counts.append(top_counts)
                top5 += 1
                res_type[i] = (5, top_smiles, gt)
                continue

            res_type[i] = (-1, best, gt)

        print(f"Top1 Correct: {top1}/{n}")
        print(f"Top5 Correct: {top5}/{n}")
        print(f"Num Invalid: {invalid}/{n}")
        print(counts)
    finally:
        # Restore RDKit warnings even if a lookup above raised.
        enable_warnings()
    return res_type


        

In [7]:
# Score step-1 single-sample predictions; only the printed summary is kept.
_ = eval_baseline(s1_results, warnings=False)

Num correct: 18/20
Num Invalid: 0/20


In [8]:
# Score step-1 self-consistency samples by majority vote.
_ = eval_sc(rez=s1_sc_results, warnings=False)

Top1 Correct: 20/20
Top5 Correct: 20/20
Num Invalid: 0/20
[(20,), (20,), (20,), (20,), (20,), (18,), (20,), (20,), (19,), (20,), (12,), (20,), (17,), (20,), (7,), (20,), (20,), (20,), (20,), (20,)]


In [10]:
# Score the `step1_prediction_zs` results (presumably zero-shot).
_ = eval_baseline(rez=s1_zs_results, warnings=False)

Num correct: 1/20
Num Invalid: 4/20


In [28]:
# Derive the path from `model` (was hardcoded to gpt-4) like the other loads.
s1_full_path = f'../../results/step1_prediction_full/{model}/step1_prediction_full_results.npy'
s1_full_results = np.load(s1_full_path, allow_pickle=True).item()
_ = eval_baseline(s1_full_results)

Num correct: 11/20
Num Invalid: 4/20


In [53]:
# FIXME(review): `x` is not defined anywhere in this notebook. It is leftover
# kernel state from a deleted or rewritten cell, so this cell fails with
# NameError on Restart & Run All. Presumably `x` was an RDKit Mol; define it
# explicitly (e.g. via Chem.MolFromSmiles) or remove this cell.
Chem.MolToSmiles(x)

'[CH3:18][S:19][C:20]#[N:21].[OH:5][c:6]1[cH+:7][cH:8][cH:9][cH:10][cH:11]1'

In [58]:
# Step-2 results for the selected model (was hardcoded to gpt-4).
s2_path = f'../../results/step2_prediction/{model}/step2_prediction_results.npy'
s2_results = np.load(s2_path, allow_pickle=True).item()
eval_baseline(s2_results)

Num correct: 1/20
Num Invalid: 1/20


[(-1,
  '[CH2:30]1[O:31][CH2:32][CH2:33][CH2:34]1.[CH:1]([CH3:2])([CH3:3])[NH:4][CH:5]([CH3:6])[CH3:7].[c:13]1([C:19](=[CH:20][CH2:21][C:22](=[O:23])[OH:24])[CH3:25])[cH:14][cH:15][cH:16][cH:17][cH:18]1',
  '[CH:1]([CH3:2])([CH3:3])[NH:4][CH:5]([CH3:6])[CH3:7].[c:13]1([C:19](=[CH:20][CH2+:21][C:22](=[O:23])[OH:24])[CH3:25])[cH:14][cH:15][cH:16][cH:17][cH:18]1'),
 (-1,
  '[C:17]([C:18]([Cl:19])=[O:20])=[O:21].[CH:1]1([CH3:15])[CH2:2][CH:3]([O:10][CH2:11][C:12](=[O:13])[OH:14])[CH:4]([CH:7]([CH3:8])[CH3:9])[CH2:5][CH2:6]1',
  '[C:17]([C:18]([Cl:19])=[O:20])=[O:21].[CH:1]1([CH3:15])[CH2:2][CH:3]([O:10][CH2:11][C:12](=[O:13])[OH:14])[CH:4]([CH:7]([CH3:8])[CH3:9])[CH2:5][CH2:6]1.[Cl:16]'),
 (-1,
  '[O:1]=[S:2]1(=[O:30])[N:3]=[C:4]([c:12]2[c:13](=[O:29])[n:14]([N:23]=[CH:24][CH2:25][CH2:26][CH2:27][CH3:28])[c:15]3[cH:16][cH:17][cH:18][cH:19][c:20]3[c:21]2[OH:22])[NH+:5][c:6]2[c:7]1[cH:8][cH:9][cH:10][cH:11]2',
  '[O:1]=[S:2]1(=[O:30])[N:3]=[C:4]([c:12]2[c:21]([OH:22])[c:20]3[cH:19][cH:18][cH

In [61]:
print(s2_results["prompts"][test_indices[0]])  # prompt sent for the first test reaction

You are an expert chemist. Your task is to predict the next intermediate molecules in the electron transfer process given the reactants and reagents, several examples, and your experienced reaction prediction knowledge. There are some rules to follow.
1. Strictly follow the given format, and only respond with an atom mapped SMILES string.
2. Numbers immediately following : represent the atom mapping.
3. A . is used to distinguish between multiple molecules in the SMILES strings.
4. If no reagents exist, it is left blank.
Reactants: [Cl:31][C:32](=[O:33])[O:34][c:35]1[cH:36][cH:37][c:38]([N+:41](=[O:42])[O-:43])[cH:39][cH:40]1.[o:1]1[c:2](-[c:6]2[o:7][c:8]([CH3:30])[c:9]([CH2:11][O:12][c:13]3[c:14]([O:28][CH3:29])[cH:15][c:16]([CH2:19][CH2:20][CH2:21][CH:22]([C:23](=[O:24])[O:25][CH3:26])[OH:27])[cH:17][cH:18]3)[n:10]2)[cH:3][cH:4][cH:5]1
Reagents: [ClH:45].[OH2:44].[cH:46]1[cH:47][cH:48][n:49][cH:50][cH:51]1
Intermediate 1: [Cl:31][C:32](=[O:33])[O:34][c:35]1[cH:36][cH:37][c:38]([N+:41

In [59]:
# Use a distinct name so the full step-2 results in `s2_results` are not
# silently overwritten; other cells that inspect s2_results stay consistent.
s2_partial_path = f'../../results/step2_prediction_partial/{model}/step2_prediction_partial_results.npy'
s2_partial_results = np.load(s2_partial_path, allow_pickle=True).item()
eval_baseline(s2_partial_results)

Num correct: 0/20
Num Invalid: 1/20


[(-1,
  '[CH:1]([CH3:2])([CH3:3])[NH+:4][CH:5]([CH3:6])[CH3:7].[c:13]1([C:19](=[CH:20][CH2:21][C:22](=[O:23])[OH:24])[CH3:25])[cH:14][cH:15][cH:16][cH:17][cH:18]1',
  '[CH:1]([CH3:2])([CH3:3])[NH:4][CH:5]([CH3:6])[CH3:7].[c:13]1([C:19](=[CH:20][CH2+:21][C:22](=[O:23])[OH:24])[CH3:25])[cH:14][cH:15][cH:16][cH:17][cH:18]1'),
 (-1,
  '[CH:1]1([CH3:15])[CH2:2][CH:3]([O:10][CH2:11][C:12](=[O:13])[OH:14])[CH:4]([CH:7]([CH3:8])[CH3:9])[CH2:5][CH2:6]1.[Cl:16][C:17]([C:18]([Cl:19])=[O:20])=[O:21].[cH:22]1[cH:23][cH:24][cH:25][cH:26][cH:27]1',
  '[C:17]([C:18]([Cl:19])=[O:20])=[O:21].[CH:1]1([CH3:15])[CH2:2][CH:3]([O:10][CH2:11][C:12](=[O:13])[OH:14])[CH:4]([CH:7]([CH3:8])[CH3:9])[CH2:5][CH2:6]1.[Cl:16]'),
 (-1,
  '[O:1]=[S:2]1(=[O:30])[N:3]=[C:4]([c:12]2[c:13](=[O:29])[n:14]([N:23]=[CH:24][CH2:25][CH2:26][CH2:27][CH3:28])[c:15]3[cH:16][cH:17][cH:18][cH:19][c:20]3[c:21]2[OH:22])[NH2+:5][c:6]2[c:7]1[cH:8][cH:9][cH:10][cH:11]2',
  '[O:1]=[S:2]1(=[O:30])[N:3]=[C:4]([c:12]2[c:21]([OH:22])[c:20]3[cH:

# Chain-of-thought (CoT) prediction

In [20]:
# evaluate results
def cot_eval_baseline(rez, warnings=False):
    """Score single-sample chain-of-thought predictions.

    Only the first prediction for each test reaction is used. The prediction
    counts as correct when the canonical ground truth is one of the
    whitespace-separated tokens of the canonicalised prediction.

    NOTE(review): the split happens after ``Chem.CanonSmiles``, whose output
    contains no whitespace, so ``pred.split()`` is ``[pred]`` and this is
    effectively an exact match. If predictions are meant to hold several
    whitespace-separated SMILES, split before canonicalising -- confirm intent.

    Args:
        rez: results dict with ``'ground_truth'`` and ``'predicted'`` entries
            indexable by the notebook-level ``test_indices``.
        warnings: if False (default), RDKit warnings are silenced while
            scoring; they are re-enabled afterwards in every case.

    Returns:
        A list of length ``n`` with ``(1, tokens, gt)`` on a hit,
        ``(-1, tokens, gt)`` on a miss, and ``(-2, None, gt)`` when the
        prediction could not be parsed.
    """
    def canon(smiles):
        # Chem.CanonSmiles raises on unparsable input; report that as None.
        try:
            return Chem.CanonSmiles(smiles)
        except Exception:
            return None

    if not warnings:
        disable_warnings()

    top1 = 0
    invalid = 0
    res_type = [0] * n
    try:
        for i in range(n):
            idx = test_indices[i]
            raw_gt = rez['ground_truth'][idx]
            # Keep scoring instead of crashing on an unparsable ground truth.
            gt = canon(raw_gt)
            if gt is None:
                gt = raw_gt
            pred = canon(rez['predicted'][idx][0])
            if pred is None:
                invalid += 1
                res_type[i] = (-2, None, gt)
                continue
            tokens = pred.split()
            if gt in tokens:
                top1 += 1
                res_type[i] = (1, tokens, gt)
            else:
                res_type[i] = (-1, tokens, gt)

        print(f"Num correct: {top1}/{n}")
        print(f"Num Invalid: {invalid}/{n}")
    finally:
        # Restore RDKit warnings even if a lookup above raised.
        enable_warnings()
    return res_type
    
def cot_eval_sc(rez, warnings=False):
    """Score chain-of-thought self-consistency predictions by majority vote.

    Sampled predictions are canonicalised (unparsable samples dropped) and
    counted. A candidate matches when the canonical ground truth is one of
    its whitespace-separated tokens. The same rule is used for top-1 and
    top-5; previously the top-5 check was a substring test on the SMILES
    string, which gave false positives (e.g. "CCO" inside "CCOC").

    Args:
        rez: results dict with ``'ground_truth'`` and ``'predicted'`` entries
            indexable by the notebook-level ``test_indices``;
            ``rez['predicted'][k]`` is an iterable of sampled SMILES strings.
        warnings: if False (default), RDKit warnings are silenced while
            scoring; they are re-enabled afterwards in every case.

    Returns:
        A list of length ``n`` with one tuple per test item:
        ``(1, top1_tokens, gt)``, ``(5, top5_smiles, gt)``,
        ``(-1, top1_smiles, gt)``, or ``(-2, None, gt)`` when no sample
        could be parsed.
    """
    def canon(smiles):
        # Chem.CanonSmiles raises on unparsable input; report that as None.
        try:
            return Chem.CanonSmiles(smiles)
        except Exception:
            return None

    def count_preds(predicted):
        # Vote count over canonical structures; unparsable samples are skipped.
        return Counter(c for c in map(canon, predicted) if c is not None)

    def matches(candidate, gt):
        # Token membership, shared by the top-1 and top-5 checks.
        return gt in candidate.split()

    if not warnings:
        disable_warnings()

    top1 = 0
    top5 = 0
    invalid = 0
    res_type = [-1] * n
    counts = []
    try:
        for i in range(n):
            idx = test_indices[i]
            raw_gt = rez['ground_truth'][idx]
            # Keep scoring instead of crashing on an unparsable ground truth.
            gt = canon(raw_gt)
            if gt is None:
                gt = raw_gt
            counter = count_preds(rez['predicted'][idx])
            if not counter:
                invalid += 1
                res_type[i] = (-2, None, gt)
                continue

            ranked = counter.most_common(5)
            best, best_count = ranked[0]
            if matches(best, gt):
                counts.append((best_count,))
                top1 += 1
                top5 += 1
                res_type[i] = (1, best.split(), gt)
                continue

            top_smiles, top_counts = zip(*ranked)
            if any(matches(cand, gt) for cand in top_smiles):
                counts.append(top_counts)
                top5 += 1
                res_type[i] = (5, top_smiles, gt)
                continue

            res_type[i] = (-1, best, gt)

        print(f"Top1 Correct: {top1}/{n}")
        print(f"Top5 Correct: {top5}/{n}")
        print(f"Num Invalid: {invalid}/{n}")
        print(counts)
    finally:
        # Restore RDKit warnings even if a lookup above raised.
        enable_warnings()
    return res_type


        

In [25]:
# Chain-of-thought results for the selected model (trusted local pickle only).
cot_path = '../../results/cot_prediction/{}/cot_prediction_results.npy'.format(model)
cot_results = np.load(cot_path, allow_pickle=True).item()

In [26]:
cot_eval_baseline(rez=cot_results, warnings=False)  # displays per-item results

Num correct: 0/20
Num Invalid: 3/20


[(-1,
  ['[CH:1]([CH3:2])([CH3:3])[NH:4][CH:5]([CH3:6])[CH3:7].[c:13]1([C:19](=[CH:20][CH2:21][C:22](=[O:23])[OH:24])[CH3:25])[cH:14][cH:15][cH:16][cH:17][cH:18]1'],
  '[CH:1]([CH3:2])([CH3:3])[CH:21]([CH:20]=[C:19]([c:13]1[cH:14][cH:15][cH:16][cH:17][cH:18]1)[CH3:25])[C:22](=[O:23])[OH:24]'),
 (-1,
  ['[CH:1]1([CH3:15])[CH2:2][CH:3]([O:10][CH2:11][C:12](=[O:13])[OH:14])[CH:4]([CH:7]([CH3:8])[CH3:9])[CH2:5][CH2:6]1.[Cl:16][C:17]([C:18]([Cl:19])=[O:20])=[O:21]'],
  '[CH:1]1([CH3:15])[CH2:2][CH:3]([O:10][CH2:11][C:12](=[O:13])[Cl:16])[CH:4]([CH:7]([CH3:8])[CH3:9])[CH2:5][CH2:6]1'),
 (-2,
  None,
  '[O:1]=[S:2]1(=[O:30])[N:3]=[C:4]([c:12]2[c:13](=[O:29])[n:14]([NH:23][CH2:24][CH2:25][CH2:26][CH2:27][CH3:28])[c:15]3[cH:16][cH:17][cH:18][cH:19][c:20]3[c:21]2[OH:22])[NH:5][c:6]2[c:7]1[cH:8][cH:9][cH:10][cH:11]2'),
 (-1,
  ['Br[c:2]1[s:3][cH:4][cH:5][n:6]1.[s:7]1[c:8]([B:12]([OH:13])[OH:14])[cH:9][cH:10][cH:11]1'],
  '[c:2]1(-[c:8]2[s:7][cH:11][cH:10][cH:9]2)[s:3][cH:4][cH:5][n:6]1'),
 (-1,
 