In [1]:
import json
from tqdm import tqdm
def process_file(file):
    """Load successful HSQC predictions from a JSON Lines file.

    Each non-blank line of ``file`` is parsed as one JSON object. A record is
    kept only when ``predictions.hsqc.status`` equals ``'SUCCESS'`` and both
    ``predictions.hsqc.H`` and ``predictions.hsqc.C`` are non-empty.

    Args:
        file: Path of the ``.jsonl`` file to read.

    Returns:
        dict: Maps each record's ``'smiles'`` value to a dict with the keys
        ``'h_nmr'`` (the record's ``H`` entry), ``'c_nmr'`` (its ``C`` entry)
        and ``'atoms'`` (its top-level ``'atoms'`` entry). When two records
        share a SMILES string, the later one replaces the earlier one.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the keys read here.
    """
    nmrs = {}
    # JSON is UTF-8 by specification, so do not rely on the platform default.
    with open(file, 'r', encoding='utf-8') as f:
        for line in f:
            # Blank or whitespace-only lines are not JSON documents; skip them
            # instead of letting json.loads abort the whole load.
            if not line.strip():
                continue
            data = json.loads(line)
            hsqc = data['predictions']['hsqc']
            if hsqc['status'] != 'SUCCESS':
                continue
            h_nmr = hsqc['H']
            c_nmr = hsqc['C']
            if h_nmr is None or c_nmr is None or len(h_nmr) == 0 or len(c_nmr) == 0:
                continue
            nmrs[data['smiles']] = {
                'h_nmr': h_nmr,
                'c_nmr': c_nmr,
                'atoms': data['atoms']
            }
    return nmrs

# Parse the benchmark predictions into a dict keyed by SMILES string.
all_nmrs = process_file(
    file='/data/nas-gpu/wang/atong/Datasets/Benchmark/benchmark_sim_3d.jsonl',
)


In [2]:
from typing import Any, Dict, List, Tuple
# SMILES -> assembled peak lists; filled by the conversion loop in a later cell.
# Re-running this cell resets it to empty.
nmr_data: Dict[str, Dict[str, List]] = dict()

def _atom_sign_from_name(atom_name: str) -> int:
    """Return -1 when ``atom_name`` contains the substring "CH2", otherwise +1.

    The result is used by ``assemble_nmr_data`` as the third element of each
    HSQC entry.
    """
    if "CH2" in atom_name:
        return -1
    return 1

def assemble_nmr_data(preds: Dict[str, Any]) -> Dict[str, List]:
    """Flatten one molecule's predicted shifts into 1H, 13C and HSQC lists.

    Args:
        preds: Mapping with the keys ``'atoms'``, ``'c_nmr'`` and ``'h_nmr'``.
            Each ``'atoms'`` item provides ``'number'`` and ``'name'``. Each
            ``'c_nmr'`` / ``'h_nmr'`` item provides ``'atom'`` (an iterable of
            dicts with an ``'index'``) and ``'shift'`` (a dict with ``'value'``
            and ``'error'``).

    Returns:
        A dict of lists:
        ``'h_nmr'`` / ``'c_nmr'`` hold one shift per (peak, atom) pair and
        ``'h_nmr_error'`` / ``'c_nmr_error'`` the matching errors;
        ``'hsqc'`` holds ``[c_shift, h_shift, sign]`` for every proton entry
        whose atom index also has a carbon shift, and ``'hsqc_error'`` holds
        ``[c_err, h_err, 0.0]`` for the same entries.

    Raises:
        ValueError: If a proton entry points at an atom named ``'CH'``,
            ``'CH2'`` or ``'CH3'`` for which no carbon shift was given.
        KeyError: If an atom index needed for a name lookup is not listed in
            ``preds['atoms']``.
    """
    data = {
        "h_nmr": [],
        "c_nmr": [],
        "hsqc": [],
        "h_nmr_error": [],
        "c_nmr_error": [],
        "hsqc_error": [],
    }

    # Atom label per (integer) atom number; drives the HSQC sign and the
    # consistency check in the proton loop.
    atom_name_by_idx: Dict[int, str] = {}
    for a in preds['atoms']:
        idx = int(a["number"])
        atom_name_by_idx[idx] = a['name']

    # Carbon shift and error per atom index. If an index occurs in more than
    # one carbon entry, the last one wins.
    c_by_atom: Dict[int, Tuple[float, float]] = {}
    for c in preds['c_nmr']:
        for atom in c['atom']:
            atom_idx = int(atom['index'])
            c_shift = float(c['shift']['value'])
            c_err = float(c['shift']['error'])
            c_by_atom[atom_idx] = (c_shift, c_err)
            data["c_nmr"].append(c_shift)
            data["c_nmr_error"].append(c_err)

    for h in preds['h_nmr']:
        for atom in h['atom']:
            # Normalise the index exactly as the carbon loop does; otherwise
            # an index serialised as a string would never match the integer
            # keys of c_by_atom / atom_name_by_idx.
            atom_idx = int(atom['index'])
            h_shift = float(h['shift']['value'])
            h_err = float(h['shift']['error'])
            data["h_nmr"].append(h_shift)
            data["h_nmr_error"].append(h_err)

            if atom_idx not in c_by_atom:
                atom_name = atom_name_by_idx[atom_idx]
                # An atom labelled as a CHn group must come with a carbon
                # shift; anything else simply has no HSQC entry.
                if atom_name in ('CH', 'CH2', 'CH3'):
                    raise ValueError(
                        f"No 13C shift found for atom {atom_idx} ({atom_name!r}), "
                        "which has a 1H shift"
                    )
                continue
            c_shift, c_err = c_by_atom[atom_idx]
            sign = _atom_sign_from_name(atom_name_by_idx[atom_idx])

            data["hsqc"].append([c_shift, h_shift, sign])
            data["hsqc_error"].append([c_err, h_err, 0.0])

    return data

In [3]:
# Convert each molecule's raw prediction record into flat peak lists.
for smiles in tqdm(all_nmrs):
    nmr = all_nmrs[smiles]
    nmr_data[smiles] = assemble_nmr_data(nmr)

100%|██████████| 121/121 [00:00<00:00, 20868.90it/s]


In [4]:
from rdkit import Chem
def canonicalize_smiles(smiles: str, keep_stereo: bool = False) -> str:
    """Return the RDKit canonical SMILES of the longest fragment of ``smiles``.

    For a multi-component SMILES (parts joined by ``.``) only one fragment is
    kept: the one whose SMILES *string* is longest. This is a text-length
    heuristic, not an atom count; on a tie the first such fragment is used.

    Args:
        smiles: Input SMILES string.
        keep_stereo: Forwarded to ``Chem.MolToSmiles`` as ``isomericSmiles``.

    Returns:
        The canonical SMILES produced by ``Chem.MolToSmiles``.

    Raises:
        ValueError: If ``smiles`` is ``None`` or empty, contains nothing but
            ``.`` separators, or cannot be parsed by RDKit.
    """
    if smiles is None or smiles == '':
        raise ValueError("Invalid empty SMILES")
    if '.' in smiles:
        fragment = max(smiles.split('.'), key=len)
        # Input such as "." leaves only empty fragments. Reject it here: the
        # None check below would not catch it, because RDKit accepts an empty
        # SMILES as an empty molecule.
        if fragment == '':
            raise ValueError(f"Invalid SMILES: {smiles}")
        smiles = fragment
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles}")
    return Chem.MolToSmiles(mol, isomericSmiles=keep_stereo, canonical=True)

In [5]:
import torch
# One benchmark entry per molecule, keyed by its position in nmr_data.
benchmark_data = {
    idx: {
        'input': {
            'h_nmr': torch.tensor(peaks['h_nmr']).reshape(-1, 1),
            'c_nmr': torch.tensor(peaks['c_nmr']).reshape(-1, 1),
            # Each HSQC row is [c_shift, h_shift, sign]. The explicit reshape
            # keeps a molecule without any C-H correlation at shape (0, 3)
            # instead of the (0,) that torch.tensor([]) would give.
            'hsqc': torch.tensor(peaks['hsqc']).reshape(-1, 3),
        },
        'smiles': canonicalize_smiles(smiles),
    }
    for idx, (smiles, peaks) in enumerate(nmr_data.items())
}

In [6]:
import pickle

# Serialise the assembled benchmark. The context manager closes (and thereby
# flushes) the file even if pickling fails part-way.
with open('/data/nas-gpu/wang/atong/Datasets/Benchmark/benchmark_sim_3d.pkl', 'wb') as pkl_file:
    pickle.dump(benchmark_data, pkl_file)