# Benchmarking Dataset Setup

On macOS, I installed the Dropbox desktop app and copied the files listed in `nmr_peaks.txt` to yuzu:
```bash
while IFS= read -r file; do
  scp "$file" "yuzu:/data/nas-gpu/wang/atong/Datasets/Benchmark"
done < nmr_peaks.txt
```

For each entry, I want to store the following before benchmarking:
```python
entry = {
    'input': {
        'hsqc': ..., # tensor shape (N, 3)
        'c_nmr': ..., # tensor shape (N, 1)
        'h_nmr': ..., # tensor shape (N, 1)
        'mw': ..., # tensor shape (1, 1)
    },
    'smiles': ..., # groundtruth canonical 2d smiles
    'npid': ..., # np id for the molecule
}
```

We need to calculate the following:
```python
stats = {
    'pred_sfp': ..., # predicted fp (l2-normalized, i.e. norm 1)
    'sfp': ..., # sherlock fp for groundtruth canonical 2d smiles
    'mfp': ..., # 2048-bit morgan fp for groundtruth canonical 2d smiles
    'cosine_sim': ..., # cosine similarity of pred fp to groundtruth
    'retrieval_idx': ..., # idx if exists in retrieval set, None otherwise
    'retrievals': {
        k: { # retrieval dict for the k-th retrieval
            'retrieval_idx': ..., # retrieval molecule idx
            'retrieval_sfp': ..., # retrieval molecule sherlock fp
            'retrieval_mfp': ..., # retrieval molecule 2048-bit morgan fp
            'cosine_sim_sfp': ..., # cosine similarity to groundtruth in sherlock fp
            'tani_sim_sfp': ..., # tanimoto similarity to groundtruth in sherlock fp
            'cosine_sim_mfp': ..., # cosine similarity to groundtruth in 2048-bit morgan fp
            'tani_sim_mfp': ..., # tanimoto similarity to groundtruth in 2048-bit morgan fp
        }
    },
    'dereplication_top1': ..., # true if dereplicated in top 1, false if not, none if gt does not exist in retrieval set
    'dereplication_top5': ..., # true if dereplicated in top 5, false if not, none if gt does not exist in retrieval set
    'dereplication_top10': ..., # true if dereplicated in top 10, false if not, none if gt does not exist in retrieval set
}
```

In [None]:
import pandas as pd
import os
import torch
from tqdm import tqdm
import csv
import pickle
from rdkit import Chem
from rdkit.Chem import Descriptors
def canonicalize_smiles(smiles: str, keep_stereo: bool = False):
    """Return the RDKit canonical SMILES for ``smiles``.

    When the input contains several '.'-separated components, only the
    component with the longest SMILES string is kept before parsing.

    Args:
        smiles: Input SMILES string.
        keep_stereo: Forwarded to RDKit as ``isomericSmiles``; ``False``
            (the default) requests non-isomeric output.

    Returns:
        The canonical SMILES string produced by ``Chem.MolToSmiles``.

    Raises:
        ValueError: If ``smiles`` is ``None`` or empty, or if RDKit cannot
            parse the selected component.
    """
    if smiles is None or smiles == '':
        raise ValueError("Invalid empty SMILES")

    # Multi-component input: keep the component whose string is longest
    # (first one wins on ties, as with ``max``).
    fragment = max(smiles.split('.'), key=len) if '.' in smiles else smiles

    parsed = Chem.MolFromSmiles(fragment)
    if parsed is None:
        raise ValueError(f"Invalid SMILES: {fragment}")

    return Chem.MolToSmiles(
        parsed,
        isomericSmiles=keep_stereo,
        canonical=True,
    )

def get_mw(smiles: str):
    """Return ``Descriptors.ExactMolWt`` for the molecule parsed from ``smiles``.

    Args:
        smiles: SMILES string, parsed as given (no fragment selection is
            performed here).

    Returns:
        The value computed by RDKit's ``Descriptors.ExactMolWt``.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is not None:
        return Descriptors.ExactMolWt(molecule)
    raise ValueError(f"Invalid SMILES: {smiles}")
# Build benchmark.pkl: one entry per Excel workbook under <BENCHMARK_ROOT>/data,
# combining its 1H, 13C and HSQC peak lists with the ground-truth SMILES looked
# up by NP id in smiles/smiles.csv.
BENCHMARK_ROOT = "/data/nas-gpu/wang/atong/Datasets/Benchmark"
DATA_DIR = os.path.join(BENCHMARK_ROOT, "data")

# Map NP id -> raw SMILES. Each CSV row must have exactly three columns; the
# first column is ignored. The file is closed as soon as the dict is built.
with open(os.path.join(BENCHMARK_ROOT, "smiles", "smiles.csv"), 'r', newline='') as smiles_file:
    smiles_dict = {npid: raw_smiles for _, npid, raw_smiles in csv.reader(smiles_file)}

data = {}
# os.listdir() returns entries in arbitrary order; sort so the integer keys of
# `data` are reproducible across runs, and skip anything that is not a workbook.
workbooks = sorted(name for name in os.listdir(DATA_DIR) if name.endswith('.xlsx'))
for file in tqdm(workbooks):
    # Open the workbook once and pull all three sheets in a single call.
    sheets = pd.read_excel(
        os.path.join(DATA_DIR, file),
        sheet_name=['1H-center', '13C', 'HSQC-ME'],
    )
    h_nmr = torch.tensor(sheets['1H-center']['f2 (ppm)'].tolist()).reshape(-1, 1)
    c_nmr = torch.tensor(sheets['13C']['ppm'].tolist()).reshape(-1, 1)

    hsqc_sheet = sheets['HSQC-ME']
    hsqc_f1 = hsqc_sheet['f1 (ppm)'].tolist()
    hsqc_f2 = hsqc_sheet['f2 (ppm)'].tolist()
    intensity = hsqc_sheet['Intensity'].tolist()
    # Column order is (f1, f2, intensity), one row per peak.
    # NOTE(review): the 1H sheet uses 'f2 (ppm)' for proton shifts, so f1 is
    # presumably the 13C axis here -- confirm the axis order the model expects.
    hsqc = torch.tensor([hsqc_f1, hsqc_f2, intensity]).T

    # Drop only the extension. str.rstrip('.xlsx') strips a *set of characters*
    # and would corrupt ids ending in 'x', 'l', 's' or '.'.
    npid = os.path.splitext(file)[0]
    raw_smiles = smiles_dict[npid]

    data[len(data)] = {
        'input': {
            'h_nmr': h_nmr,
            'c_nmr': c_nmr,
            'hsqc': hsqc,
            # NOTE(review): stored as a plain float computed from the raw SMILES
            # (all '.'-separated components), while the notebook spec describes
            # a (1, 1) tensor -- confirm what downstream code expects.
            'mw': get_mw(raw_smiles)
        },
        'smiles': canonicalize_smiles(raw_smiles), # groundtruth canonical 2d smiles
        '3d_smiles': canonicalize_smiles(raw_smiles, keep_stereo=True), # groundtruth canonical 3d smiles
        'npid': npid, # np id for the molecule
    }

# Write through a context manager so the pickle is flushed and closed.
with open(os.path.join(BENCHMARK_ROOT, "benchmark.pkl"), 'wb') as out_file:
    pickle.dump(data, out_file)


100%|██████████| 121/121 [00:02<00:00, 54.88it/s]
