In [1]:
# add chempe
import sys
sys.path.append('../..')


# Load Data

In [2]:
import numpy as np
import os

In [3]:
# Loads results

model = 'gpt-4'
#model = 'gpt-3.5-turbo'

load_path = lambda test: f"../../results/{test}/{model}/{test}_results.npy"
zero_shot_path = load_path('zero_shot')
few_shot_path = load_path('few_shot')
zero_shot_results = np.load(zero_shot_path, allow_pickle=True).item()
few_shot_results = np.load(few_shot_path, allow_pickle=True).item()
sc_results = np.load(load_path('self_consistency'), allow_pickle=True, ).item()

In [4]:
# Loads indices
test_indices = np.load("../../data/test_indices.npy", allow_pickle=True)
# number of tests
n = 20


# Evaluate Data

In [5]:
import rdkit
from rdkit import Chem
import tqdm
from chempe.data.warnings import disable_warnings, enable_warnings
from collections import Counter


In [57]:
# evaluate results
def eval_baseline(rez, warnings = False):
    """Only 1 result"""
    if not warnings:
        disable_warnings()
        
    top1 = 0
    invalid = 0
    res_type = [0] * n


    for i in range(n):
        # canonical gt
        gt = Chem.CanonSmiles(rez['ground_truth'][test_indices[i]])
        predicted = rez['predicted'][test_indices[i]][0]
        try:
            pred = Chem.CanonSmiles(predicted).split(".")
            if gt in pred:
                top1 += 1
                res_type[i] = (1, pred, gt)
            else:
                res_type[i] = (-1, pred, gt)
        except Exception as e:
            # if failed to parse
            invalid += 1
            res_type[i] = (-2, None, gt)
        
    print(f"Num correct: {top1}/{n}")
    print(f"Num Invalid: {invalid}/{n}")
    
    enable_warnings()
    return res_type
    
def eval_sc(rez, warnings = False):
    """Use consistency"""
    if not warnings:
        disable_warnings()
    
    top1 = 0
    top5 = 0
    invalid = 0
    res_type = [-1] * n
    
    def count_preds(predicted):
        canon_pred = []
        tmp = 0
        for p in predicted:
            try:
                canon_pred.append(Chem.CanonSmiles(p))
            except:
                tmp += 1
                pass
        #print(tmp)
        return Counter(canon_pred)
        

    for i in range(n):
        # canonical gt
        tmp5 = False
        gt = Chem.CanonSmiles(rez['ground_truth'][test_indices[i]])
        counter = count_preds(rez['predicted'][test_indices[i]])
        if len(counter) == 0:
            invalid += 1
            res_type[i] = (-2, None, gt)
            continue

        # top 1
        pred = counter.most_common()[0]
        if gt in pred[0].split():
            top1 += 1
            res_type[i] = (1, pred[0].split(), gt)
            continue
        else:
            res_type[i] = (-1, pred[0].split(), gt)
        
    print(f"Top1 Correct: {top1}/{n}")
    print(f"Num Invalid: {invalid}/{n}")
    
    enable_warnings()
    return res_type


        

In [39]:
# zero shot
zero_shot_res_types = eval_baseline(zero_shot_results)

Num correct: 0/20
Num Invalid: 4/20


In [55]:
# few shot results
few_shot_res_types = eval_baseline(few_shot_results)

Num correct: 2/20
Num Invalid: 3/20


In [59]:
# self-consistency
self_consistency_res_types = eval_sc(sc_results)

10
7
4
15
4
17
12
14
2
13
8
17
5
8
4
8
7
6
5
12
Top1 Correct: 1/20
Num Invalid: 0/20


# Observe results

In [47]:
def filter_res(res_type):
    inc = list(filter(lambda x: x[0] == -1, res_type))
    inv = list(filter(lambda x: x[0] == -2, res_type))
    cor = list(filter(lambda x: x[0] == 1, res_type))
    return cor, inc, inv

Zero Shot

In [48]:
cor, inc, inv = filter_res(zero_shot_res_types)

In [50]:
# none correct
len(cor)

0

In [53]:
# check incorrect
for _, p, gt in inc:
    print(p)
    print(gt)
    print()

['[CH:1]([CH3:2])([CH3:3])[NH:4][CH:5]([CH3:6])[CH3:7]', '[c:13]1([C:19](=[CH:20][CH2:21][C:22](=[O:23])[OH:24])[CH3:25])[cH:14][cH:15][cH:16][cH:17][cH:18]1']
[CH:1]([CH3:2])([CH3:3])[CH:21]([CH:20]=[C:19]([c:13]1[cH:14][cH:15][cH:16][cH:17][cH:18]1)[CH3:25])[C:22](=[O:23])[OH:24]

['[CH:1]1([CH3:15])[CH2:2][CH:3]([O:10][CH2:11][C:12](=[O:13])[OH:14])[CH:4]([CH:7]([CH3:8])[CH3:9])[CH2:5][CH2:6]1', '[Cl:16][C:17]([C:18](=[O:20])[O:21])=[O:20]']
[CH:1]1([CH3:15])[CH2:2][CH:3]([O:10][CH2:11][C:12](=[O:13])[Cl:16])[CH:4]([CH:7]([CH3:8])[CH3:9])[CH2:5][CH2:6]1

['[O:1]=[S:2]1(=[O:30])[N:3]=[C:4]([c:12]2[c:13](=[O:29])[n:14]([N:23]=[CH:24][CH2:25][CH2:26][CH2:27][CH3:28])[c:15]3[cH:16][cH:17][cH:18][cH:19][c:20]3[c:21]2[OH:22])[NH:5][c:6]2[c:7]1[cH:8][cH:9][cH:10][cH:11]2']
[O:1]=[S:2]1(=[O:30])[N:3]=[C:4]([c:12]2[c:13](=[O:29])[n:14]([NH:23][CH2:24][CH2:25][CH2:26][CH2:27][CH3:28])[c:15]3[cH:16][cH:17][cH:18][cH:19][c:20]3[c:21]2[OH:22])[NH:5][c:6]2[c:7]1[cH:8][cH:9][cH:10][cH:11]2

['[Br:

Few Shot

In [60]:
cor, inc, inv = filter_res(few_shot_res_types)

In [61]:
# cor
for (_, p, gt) in cor:
    print(p)
    print(gt)
    print()

['[O:1]=[S:2]1(=[O:30])[N:3]=[C:4]([c:12]2[c:13](=[O:29])[n:14]([NH:23][CH2:24][CH2:25][CH2:26][CH2:27][CH3:28])[c:15]3[cH:16][cH:17][cH:18][cH:19][c:20]3[c:21]2[OH:22])[NH:5][c:6]2[c:7]1[cH:8][cH:9][cH:10][cH:11]2']
[O:1]=[S:2]1(=[O:30])[N:3]=[C:4]([c:12]2[c:13](=[O:29])[n:14]([NH:23][CH2:24][CH2:25][CH2:26][CH2:27][CH3:28])[c:15]3[cH:16][cH:17][cH:18][cH:19][c:20]3[c:21]2[OH:22])[NH:5][c:6]2[c:7]1[cH:8][cH:9][cH:10][cH:11]2

['[CH2:2]([C:3]#[C:4][CH3:5])[O:25][c:22]1[cH:21][cH:20][c:19]([S:16]([N:15]2[CH:10]([C:8]([O:7][CH3:6])=[O:9])[CH:11]([CH3:27])[O:12][CH:13]([CH3:26])[CH2:14]2)(=[O:17])=[O:18])[cH:24][cH:23]1']
[CH2:2]([C:3]#[C:4][CH3:5])[O:25][c:22]1[cH:21][cH:20][c:19]([S:16]([N:15]2[CH:10]([C:8]([O:7][CH3:6])=[O:9])[CH:11]([CH3:27])[O:12][CH:13]([CH3:26])[CH2:14]2)(=[O:17])=[O:18])[cH:24][cH:23]1



In [62]:
# cor
for (_, p, gt) in inc:
    print(p)
    print(gt)
    print()

['[CH:1]([CH3:2])([CH3:3])[NH:4][CH:5]([CH3:6])[CH3:7]', '[c:13]1([C:19](=[CH:20][CH2:21][C:22](=[O:23])[OH:24])[CH3:25])[cH:14][cH:15][cH:16][cH:17][cH:18]1']
[CH:1]([CH3:2])([CH3:3])[CH:21]([CH:20]=[C:19]([c:13]1[cH:14][cH:15][cH:16][cH:17][cH:18]1)[CH3:25])[C:22](=[O:23])[OH:24]

['[CH:1]1([CH3:15])[CH2:2][CH:3]([O:10][CH2:11][C:12](=[O:13])[O:21][C:17]([Cl:16])[C:18]([Cl:19])=[O:20])[CH:4]([CH:7]([CH3:8])[CH3:9])[CH2:5][CH2:6]1']
[CH:1]1([CH3:15])[CH2:2][CH:3]([O:10][CH2:11][C:12](=[O:13])[Cl:16])[CH:4]([CH:7]([CH3:8])[CH3:9])[CH2:5][CH2:6]1

['[c:2]1([O:20][C:18]([O:17])=[O:19])[s:3][cH:4][cH:5][n:6]1']
[c:2]1(-[c:8]2[s:7][cH:11][cH:10][cH:9]2)[s:3][cH:4][cH:5][n:6]1

['[CH2:1]([c:2]1[cH:3][cH:4][cH:5][cH:6][cH:7]1)[O:8][CH:9]1[CH:10]([CH:52]=[O:53])[O:11][C:12]([O:31][CH3:32])([c:33]2[cH:34][c:35]([CH2:40][c:41]3[c:42]([F:51])[c:43]([F:50])[c:44]([O:47][CH2:48][CH3:49])[cH:45][cH:46]3)[c:36]([Cl:39])[cH:37][cH:38]2)[CH:13]([O:23][CH2:24][c:25]2[cH:26][cH:27][cH:28][cH:29][cH:30]2

Self Consistency

In [67]:
cor, inc, inv = filter_res(self_consistency_res_types)

In [68]:
# cor
for (_, p, gt) in cor:
    print(p)
    print(gt)
    print()

['[c:2]1([NH:17][c:18]2[cH:19][c:20]([NH:24][S:25](=[O:26])(=[O:27])[CH2:28][c:29]3[cH:30][cH:31][cH:32][cH:33][cH:34]3)[cH:21][cH:22][cH:23]2)[o:3][c:4](-[c:7]2[cH:8][cH:9][c:10]([C:13]([F:14])([F:15])[F:16])[cH:11][cH:12]2)[cH:5][n:6]1']
[c:2]1([NH:17][c:18]2[cH:19][c:20]([NH:24][S:25](=[O:26])(=[O:27])[CH2:28][c:29]3[cH:30][cH:31][cH:32][cH:33][cH:34]3)[cH:21][cH:22][cH:23]2)[o:3][c:4](-[c:7]2[cH:8][cH:9][c:10]([C:13]([F:14])([F:15])[F:16])[cH:11][cH:12]2)[cH:5][n:6]1



In [69]:
# inc
for (_, p, gt) in inc:
    print(p)
    print(gt)
    print()

['[CH:1]([CH3:2])([CH3:3])[NH:4][CH:5]([CH3:6])[CH3:7].[c:13]1([C:19](=[CH:20][CH2:21][C:22](=[O:23])[OH:24])[CH3:25])[cH:14][cH:15][cH:16][cH:17][cH:18]1']
[CH:1]([CH3:2])([CH3:3])[CH:21]([CH:20]=[C:19]([c:13]1[cH:14][cH:15][cH:16][cH:17][cH:18]1)[CH3:25])[C:22](=[O:23])[OH:24]

['[CH:1]1([CH3:15])[CH2:2][CH:3]([O:10][CH2:11][C:12](=[O:13])[OH:14])[CH:4]([C:17]([C:18]([Cl:19])=[O:20])=[O:21])[CH2:5][CH2:6]1']
[CH:1]1([CH3:15])[CH2:2][CH:3]([O:10][CH2:11][C:12](=[O:13])[Cl:16])[CH:4]([CH:7]([CH3:8])[CH3:9])[CH2:5][CH2:6]1

['[O:1]=[S:2]1(=[O:30])[N:3]=[C:4]([c:12]2[c:13](=[O:29])[n:14]([NH2:23])[c:15]3[cH:16][cH:17][cH:18][cH:19][c:20]3[c:21]2[OH:22])[NH:5][c:6]2[c:7]1[cH:8][cH:9][cH:10][cH:11]2']
[O:1]=[S:2]1(=[O:30])[N:3]=[C:4]([c:12]2[c:13](=[O:29])[n:14]([NH:23][CH2:24][CH2:25][CH2:26][CH2:27][CH3:28])[c:15]3[cH:16][cH:17][cH:18][cH:19][c:20]3[c:21]2[OH:22])[NH:5][c:6]2[c:7]1[cH:8][cH:9][cH:10][cH:11]2

['[c:2]1([O:25][CH2:21][CH2:22][O:23][CH3:24])[s:3][cH:4][cH:5][n:6]1']
[c:2]1(