In [18]:
# Make the repository root importable so that the `chempe` package resolves.
# The membership check keeps the cell idempotent: re-running it does not
# append duplicate entries to sys.path.
import sys
if '../..' not in sys.path:
    sys.path.append('../..')


In [19]:
import numpy as np
import os

# Step 1 Prediction

In [20]:
import rdkit
from rdkit import Chem
import tqdm
from chempe.data.warnings import disable_warnings, enable_warnings
from collections import Counter


In [54]:
# Name of the model whose stored step-1 predictions are evaluated below;
# it selects the results sub-directory.
model = 'gpt-4'

# Result files: single prediction per example, and the "_sc" variant
# (presumably self-consistency, i.e. several sampled predictions per example).
s1_pred_path = f'../../results/step1_prediction/{model}/step1_prediction_results.npy'
s1_pred_sc_path = f'../../results/step1_prediction_sc/{model}/step1_prediction_sc_results.npy'

# Each file holds a pickled object wrapped in a 0-d array; .item() unwraps it.
# NOTE(review): allow_pickle=True executes pickle on load -- only use on
# trusted, locally produced result files.
s1_results = np.load(s1_pred_path, allow_pickle=True).item()
s1_sc_results = np.load(s1_pred_sc_path, allow_pickle=True).item()

In [55]:
# Number of test examples to evaluate (the first `n` entries of test_indices).
n = 20
# Indices used to look up entries in the results' 'ground_truth' and
# 'predicted' collections.
test_indices = np.load("../../data/test_indices.npy", allow_pickle=True)

In [56]:
# evaluate results
def eval_baseline(rez, warnings=False):
    """Score single-prediction results against the canonical ground truth.

    For each of the first ``n`` entries of the module-level ``test_indices``,
    the first element of the stored prediction list is canonicalised with
    ``Chem.CanonSmiles`` and compared with the canonicalised ground truth.
    A summary is printed.

    Args:
        rez: Mapping with ``'ground_truth'`` and ``'predicted'`` entries, both
            indexable by the values of ``test_indices``. Each ``'predicted'``
            entry is a sequence; only its first element is scored.
        warnings: If false, ``disable_warnings()`` is called before scoring.
            ``enable_warnings()`` is always called before returning, including
            when an exception propagates.

    Returns:
        A list of ``n`` tuples ``(status, prediction, ground_truth)`` where
        ``status`` is ``1`` (canonical forms match), ``-1`` (prediction
        canonicalised but differs) or ``-2`` (prediction missing or could not
        be canonicalised; ``prediction`` is ``None``).

    Raises:
        Any exception raised while canonicalising a ground-truth entry or
        while looking up an index in ``rez`` is propagated.
    """
    if not warnings:
        disable_warnings()

    top1 = 0
    invalid = 0
    res_type = [0] * n

    try:
        for i in range(n):
            idx = test_indices[i]
            # canonical gt
            gt = Chem.CanonSmiles(rez['ground_truth'][idx])
            candidates = rez['predicted'][idx]
            try:
                # Indexing happens inside the try so that an empty prediction
                # list is counted as invalid instead of raising IndexError.
                pred = Chem.CanonSmiles(candidates[0])
            except Exception:
                # missing prediction or failed to parse
                invalid += 1
                res_type[i] = (-2, None, gt)
                continue

            if gt == pred:
                top1 += 1
                res_type[i] = (1, pred, gt)
            else:
                res_type[i] = (-1, pred, gt)

        print(f"Num correct: {top1}/{n}")
        print(f"Num Invalid: {invalid}/{n}")
    finally:
        # Restore warnings even if scoring failed part-way through, so a
        # failure here does not leave them silenced for the rest of the session.
        enable_warnings()
    return res_type
    
def eval_sc(rez, warnings=False):
    """Score multi-sample results by majority vote over canonical predictions.

    For each of the first ``n`` entries of the module-level ``test_indices``,
    every stored prediction is canonicalised with ``Chem.CanonSmiles``
    (predictions that fail to canonicalise are skipped) and the canonical
    forms are tallied. The most frequent form is the top-1 answer; the five
    most frequent forms are the top-5 answers. A summary is printed, followed
    by the vote counts collected for the examples that were hits.

    Args:
        rez: Mapping with ``'ground_truth'`` and ``'predicted'`` entries, both
            indexable by the values of ``test_indices``. Each ``'predicted'``
            entry is an iterable of candidate predictions.
        warnings: If false, ``disable_warnings()`` is called before scoring.
            ``enable_warnings()`` is always called before returning, including
            when an exception propagates.

    Returns:
        A list of ``n`` tuples, one per example:
        ``(1, top_prediction, gt)`` if the most frequent form equals the
        ground truth; ``(5, top5_predictions, gt)`` if the ground truth is
        among the five most frequent forms (``top5_predictions`` is a tuple);
        ``(-1, top_prediction, gt)`` if it is in neither; and
        ``(-2, None, gt)`` if no prediction could be canonicalised.

    Raises:
        Any exception raised while canonicalising a ground-truth entry or
        while looking up an index in ``rez`` is propagated.
    """
    if not warnings:
        disable_warnings()

    top1 = 0
    top5 = 0
    invalid = 0
    res_type = [-1] * n
    counts = []

    def count_preds(predicted):
        """Tally canonical forms of the predictions, skipping unparsable ones."""
        canon_pred = []
        for p in predicted:
            try:
                canon_pred.append(Chem.CanonSmiles(p))
            except Exception:
                # Unparsable prediction: ignore it. Catching Exception rather
                # than using a bare except lets KeyboardInterrupt/SystemExit
                # propagate.
                continue
        return Counter(canon_pred)

    try:
        for i in range(n):
            idx = test_indices[i]
            # canonical gt
            gt = Chem.CanonSmiles(rez['ground_truth'][idx])
            counter = count_preds(rez['predicted'][idx])
            if len(counter) == 0:
                invalid += 1
                res_type[i] = (-2, None, gt)
                continue

            # One ranking serves both checks: ties keep first-encountered
            # order, so ranked[0] is also the overall most common entry.
            ranked = counter.most_common(5)
            best, best_count = ranked[0]

            # top 1
            if gt == best:
                counts.append((best_count,))
                top1 += 1
                top5 += 1
                res_type[i] = (1, best, gt)
                continue

            # top 5
            pred5, cs = zip(*ranked)
            if gt in pred5:
                counts.append(cs)
                top5 += 1
                res_type[i] = (5, pred5, gt)
                continue

            res_type[i] = (-1, best, gt)

        print(f"Top1 Correct: {top1}/{n}")
        print(f"Top5 Correct: {top5}/{n}")
        print(f"Num Invalid: {invalid}/{n}")
        print(counts)
    finally:
        # Restore warnings even if scoring failed part-way through.
        enable_warnings()
    return res_type


        

In [57]:
# Score the single-prediction results; only the printed summary is of
# interest here, so the per-example list is bound to a throwaway name.
_ = eval_baseline(s1_results)

Num correct: 18/20
Num Invalid: 0/20


In [58]:
# Score the multi-sample ("_sc") results; only the printed summary is of
# interest here, so the per-example list is bound to a throwaway name.
_ = eval_sc(rez=s1_sc_results)

Top1 Correct: 20/20
Top5 Correct: 20/20
Num Invalid: 0/20
[(20,), (20,), (20,), (20,), (20,), (18,), (20,), (20,), (19,), (20,), (12,), (20,), (17,), (20,), (7,), (20,), (20,), (20,), (20,), (20,)]
