# Dataset migration from v3 to Mnova ablation studies

We are going to remove the existing 1D data and replace it with Mnova-simulated data.

`v4`: Same data, just with Mnova 1d data replacing the existing train/val/test

`v4_large`: Train on whole retrieval set with Mnova simulations

### Step 1: Cleaning Dataset for Mnova Simulation

There are many repeated SMILES strings; we want to simulate each unique SMILES only once and also drop anything with a very large molecular weight (>1800 Da originally; the code below applies a stricter 1000 Da cutoff).

Statistics:
- 216,586 entries in train/val/test
- 190,164 unique SMILES across train/val/test
- 189,691 unique SMILES kept for simulation (<=1800 Da)
- 182,637 unique SMILES kept for simulation (<=1000 Da)

In [1]:
from rdkit import Chem
from rdkit.Chem import Descriptors

def check_invalid_mol(smiles: str, max_mw: float = 1000.0) -> bool:
    """
    Decide whether a SMILES string should be excluded from simulation.

    Args:
        smiles: SMILES string to check. An empty (or `None`) value counts as
            invalid.
        max_mw: Largest molecular weight (Da) that is still accepted.
            Defaults to 1000.0, the cutoff this function has always applied.

    Returns:
        True if `smiles` is empty, cannot be parsed into a molecule, or
        describes a molecule whose molecular weight is strictly greater than
        `max_mw`. False otherwise.
    """
    if not smiles:
        return True

    # A `None` result means the SMILES could not be turned into a molecule.
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return True

    return Descriptors.MolWt(mol) > max_mw

In [7]:
import pickle
import os
from tqdm import tqdm

DATASET_ROOT = '/data/nas-gpu/wang/atong/MoonshotDatasetv3'
OUTPUT_ROOT = '/data/nas-gpu/wang/atong/MoonshotDatasetv4/WorkingDir'
os.makedirs(OUTPUT_ROOT, exist_ok=True)

# Load the index inside a context manager so the file handle is closed
# right away instead of being left open until garbage collection.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(os.path.join(DATASET_ROOT, 'index.pkl'), 'rb') as f:
    index: dict[int, dict] = pickle.load(f)

# De-duplicate SMILES so every structure is checked only once.
all_smiles: list[str] = list({e['smiles'] for e in index.values()})
# Keep only the SMILES that check_invalid_mol does not reject.
valid_smiles: list[str] = [s for s in tqdm(all_smiles) if not check_invalid_mol(s)]

# One SMILES per line, no trailing newline.
valid_smiles_path = os.path.join(OUTPUT_ROOT, 'smiles_1000.txt')
with open(valid_smiles_path, 'w') as f:
    f.write('\n'.join(valid_smiles))

# Note: the denominator is the number of index entries, not the number of
# unique SMILES that were checked.
print(f'Wrote {len(valid_smiles)}/{len(index)} valid SMILES to {valid_smiles_path}')

100%|██████████| 190164/190164 [00:18<00:00, 10218.07it/s]


Wrote 182637/216586 valid SMILES to /data/nas-gpu/wang/atong/MoonshotDatasetv4/WorkingDir/smiles_1000.txt


In [None]:
import pickle
import os
from tqdm import tqdm

DATASET_ROOT = '/data/nas-gpu/wang/atong/Datasets/MoonshotDatasetv3'
OUTPUT_ROOT = '/data/nas-gpu/wang/atong/Datasets/MoonshotDatasetv4/WorkingDir'
os.makedirs(OUTPUT_ROOT, exist_ok=True)

# Load the retrieval set inside a context manager so the file handle is
# closed right away instead of being left open until garbage collection.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(os.path.join(DATASET_ROOT, 'retrieval.pkl'), 'rb') as f:
    index: dict[int, dict] = pickle.load(f)

# De-duplicate SMILES so every structure is checked only once.
all_smiles: list[str] = list({e['smiles'] for e in index.values()})
# Keep only the SMILES that check_invalid_mol does not reject.
valid_smiles: list[str] = [s for s in tqdm(all_smiles) if not check_invalid_mol(s)]

# One SMILES per line, no trailing newline.
valid_smiles_path = os.path.join(OUTPUT_ROOT, 'smiles_1000.txt')
with open(valid_smiles_path, 'w') as f:
    f.write('\n'.join(valid_smiles))

# Note: the denominator is the number of loaded entries, not the number of
# unique SMILES that were checked.
print(f'Wrote {len(valid_smiles)}/{len(index)} valid SMILES to {valid_smiles_path}')

# JSONL Format for MoonshotDataset
```json
{
    "idx": 0,
    "smiles": "C=CC1CN2CCC1CC2C(O)c1ccnc2ccc(OC)cc12",
    "split": "train",
    "has_hsqc": true,
    "has_c_nmr": false,
    "has_h_nmr": false,
    "has_mass_spec": true,
    "has_iso_dist": true,
    "mw": 404.1099,
    "name": "quinine hydrobromide",
    "has_mw": true,
    "formula": "C20H25BrN2O2",
    "has_formula": true,
    "np_pathway": ["Alkaloids"],
    "np_superclass": ["Tryptophan alkaloids"],
    "np_class": [],
    "hsqc": [
        [54.89, 3.077, -1.0], 
        ...
    ],
    "mass_spec": [
        [93.06987762451172, 1.3903098106384277], 
        ...
    ],
    "c_nmr": [
        2.130000114440918, 2.25, ...
    ],
    "h_nmr": [
        2.130000114440918, 2.25, ...
    ],
    "fragidx": [
        1, 3, 4, 5, 6, 8, 9, 14, 15, 17, 19, 20, 29, 32, 35, 46, 50, 59, 60, 62, 69, 80, 81, 125, 184, 210, 240, 378, 382, 392, 472, 491, 698, 891, 933, 1439, 1486, 1639, 2185, 2792, 5479, 5610, 6182, 7318, 7903, 11697, 14346
    ]
}
```

# JSONL Format for Predictions

```json
{
    "idx":0,
    "smiles":"COc1ccc2cc1Oc1cc(ccc1O)C(O)C13SSSC4(C(=O)N1C)C(O)C1=COC=CC(OC2=O)C1N4C3=O",
    "status":"SUCCESS",
    "error":null,
    "atoms":[
        {"number":"1","name":"CH3"},
        {"number":"2","name":"O"},
        ...
    ],
    "predictions":{
        "hsqc":{
            "status":"SUCCESS",
            "H":[
                {
                    "atom":[{"index":1}],
                    "shift":{"value":3.9237684493895095,"error":0.1},
                    "js":[{"atom":[{"index":1}],"j":{"value":5.46,"error":1.72}}]
                },
                {
                    "atom":[{"index":4}],
                    "shift":{"value":7.14351619175921,"error":0.1},
                    "js":[{"atom":[{"index":5}],"j":{"value":8.65,"error":0.31}},{"atom":[{"index":7}],"j":{"value":0.1,"error":0.1}}]
                },
                {
                    "atom":[{"index":5}],
                    "shift":{"value":7.5815223902024345,"error":0.1},
                    "js":[{"atom":[{"index":4}],"j":{"value":8.65,"error":0.31}},{"atom":[{"index":7}],"j":{"value":2.03,"error":0.48}}]
                },
                ...
            ],
            "C":[
                {
                    "atom":[{"index":1}],
                    "shift":{"value":56.0195331969619,"error":3},
                    "js":[{"atom":[{"index":1}],"j":{"value":143.96,"error":3.91}},{"atom":[{"index":4}],"j":{"value":0.4,"error":1.18}}]
                },
                {
                    "atom":[{"index":3}],
                    "shift":{"value":154.3666380106334,"error":3},
                    "js":[{"atom":[{"index":1}],"j":{"value":3.97,"error":12.36}},{"atom":[{"index":4}],"j":{"value":2.52,"error":4.27}},{"atom":[{"index":5}],"j":{"value":7.17,"error":2.84}},{"atom":[{"index":7}],"j":{"value":6.43,"error":2.84}}]
                },
                ...
            ],
            "error":null
        }
    }
}
```

# Dataset Forms

We will reduce the dataset to the following forms:

Form 1

- **MARINABase1**:
    - MoonshotDatasetv3 without any molecules that errored during predictions, keeping only molecules with molecular weight <= 1000 Da
- **MARINADataset1**:
    - MARINABase1, with replacing all existing C/H NMRs
- **MARINADataset2**:
    - MARINABase1, put all C/H simulated NMR that exist
- **MARINADataset3**:
    - MARINABase1, replacing all existing C/H/HSQC NMRs
- **MARINADataset4**:
    - MARINABase1, put all C/H/HSQC simulated NMR that exist

Form 2

- **MARINABase2**:
    - MoonshotDatasetv3 without any molecules that errored during predictions, keeping only molecules with molecular weight within [100 Da, 1000 Da]
- **MARINAMedDataset1**:
    - MARINABase2, with replacing all existing C/H NMRs
- **MARINAMedDataset2**:
    - MARINABase2, put all C/H simulated NMR that exist
- **MARINAMedDataset3**:
    - MARINABase2, replacing all existing C/H/HSQC NMRs
- **MARINAMedDataset4**:
    - MARINABase2, put all C/H/HSQC simulated NMR that exist

Form 3

- **MARINABaseNoDup**:
    - MoonshotDatasetv3 without any molecules that errored during predictions, keeping only molecules with molecular weight <= 1000 Da, and with duplicate SMILES removed


In [None]:
import glob
import os
import json
from tqdm import tqdm
# Directory that is searched for raw `*.jsonl` prediction files.
PRED_ROOT = '/data/nas-gpu/wang/atong/Datasets/MnovaPredictions/raw'
# Collect the matching paths, then order them in place so the processing
# order is deterministic.
files = glob.glob(os.path.join(PRED_ROOT, '*.jsonl'))
files.sort()
def process_file(file):
    """
    Collect the successful HSQC predictions from one JSONL file.

    Each non-blank line is parsed as a JSON record. A record is kept only
    when `predictions.hsqc.status` equals 'SUCCESS' and both the 'H' and 'C'
    prediction lists are present and non-empty. Blank lines and records whose
    `predictions` or `hsqc` entry is missing or null are skipped instead of
    raising.

    Args:
        file: Path of the JSONL file to read.

    Returns:
        Dict mapping each kept record's 'smiles' to a dict with the keys
        'h_nmr', 'c_nmr' and 'atoms'. If a SMILES occurs more than once in
        the file, the last kept record wins.
    """
    nmrs = {}
    with open(file, 'r', encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            data = json.loads(line)
            # A failed record may carry no (or a null) predictions/hsqc
            # entry; treat that the same as a non-SUCCESS status.
            hsqc = (data.get('predictions') or {}).get('hsqc') or {}
            if hsqc.get('status') != 'SUCCESS':
                continue
            h_nmr = hsqc.get('H')
            c_nmr = hsqc.get('C')
            # Both the proton and the carbon predictions must be non-empty.
            if not h_nmr or not c_nmr:
                continue
            nmrs[data['smiles']] = {
                'h_nmr': h_nmr,
                'c_nmr': c_nmr,
                'atoms': data['atoms'],
            }
    return nmrs

# Merge the per-file results into one dict; when the same key shows up in
# several files, the entry from the file processed last is kept.
all_nmrs = {}
for file in tqdm(files):
    nmrs = process_file(file)
    all_nmrs |= nmrs


100%|██████████| 183/183 [02:12<00:00,  1.38it/s]


In [None]:
from typing import Any, Dict, List, Tuple
# Assembled NMR data per molecule: SMILES -> output of assemble_nmr_data.
nmr_data: Dict[str, Dict[str, List]] = {}

def _atom_sign_from_name(atom_name: str) -> int:
    """Return -1 when `atom_name` contains the substring "CH2", otherwise 1."""
    if "CH2" in atom_name:
        return -1
    return 1

def assemble_nmr_data(preds: Dict[str, Any]) -> Dict[str, List]:
    """
    Flatten one molecule's predictions into shift lists and HSQC peaks.

    Args:
        preds: Mapping with the keys 'atoms' (dicts holding 'number' and
            'name'), 'c_nmr' and 'h_nmr'. Each entry of 'c_nmr' / 'h_nmr'
            holds an 'atom' list of ``{'index': ...}`` dicts and a 'shift'
            dict with 'value' and 'error'.

    Returns:
        Dict with six lists:
        - 'h_nmr' / 'c_nmr': one shift value per (entry, atom) pair.
        - 'h_nmr_error' / 'c_nmr_error': the matching shift errors.
        - 'hsqc': ``[c_shift, h_shift, sign]`` for every proton whose atom
          index also has a carbon shift; sign comes from
          `_atom_sign_from_name` applied to the atom's name.
        - 'hsqc_error': ``[c_err, h_err, 0.0]`` for each HSQC peak.

    Raises:
        ValueError: If a proton sits on an atom named 'CH', 'CH2' or 'CH3'
            for which no carbon shift was given.
    """
    data: Dict[str, List] = {
        "h_nmr": [],
        "c_nmr": [],
        "hsqc": [],
        "h_nmr_error": [],
        "c_nmr_error": [],
        "hsqc_error": [],
    }

    # Atom numbers are normalised to int so they can be matched against the
    # indices used by the shift entries.
    atom_name_by_idx: Dict[int, str] = {
        int(a["number"]): a["name"] for a in preds["atoms"]
    }

    # Carbon shifts, indexed by atom, for pairing with protons below.
    c_by_atom: Dict[int, Tuple[float, float]] = {}
    for c in preds["c_nmr"]:
        for atom in c["atom"]:
            atom_idx = int(atom["index"])
            c_shift = float(c["shift"]["value"])
            c_err = float(c["shift"]["error"])
            c_by_atom[atom_idx] = (c_shift, c_err)
            data["c_nmr"].append(c_shift)
            data["c_nmr_error"].append(c_err)

    for h in preds["h_nmr"]:
        for atom in h["atom"]:
            # Coerce to int exactly like the carbon loop does; otherwise an
            # index given as a string would never match a carbon entry.
            atom_idx = int(atom["index"])
            h_shift = float(h["shift"]["value"])
            h_err = float(h["shift"]["error"])
            data["h_nmr"].append(h_shift)
            data["h_nmr_error"].append(h_err)

            if atom_idx not in c_by_atom:
                atom_name = atom_name_by_idx[atom_idx]
                # A proton on a carbon-named atom must have a carbon shift;
                # protons on other atoms simply yield no HSQC peak.
                if atom_name in ("CH", "CH2", "CH3"):
                    raise ValueError(
                        f"No carbon shift for atom {atom_idx} ({atom_name}) "
                        "that carries a predicted proton shift"
                    )
                continue

            c_shift, c_err = c_by_atom[atom_idx]
            sign = _atom_sign_from_name(atom_name_by_idx[atom_idx])

            data["hsqc"].append([c_shift, h_shift, sign])
            data["hsqc_error"].append([c_err, h_err, 0.0])

    return data

In [29]:
# Assemble the predictions of every molecule and store them under its SMILES.
for smiles, nmr in tqdm(all_nmrs.items()):
    try:
        assembled = assemble_nmr_data(nmr)
    except ValueError:
        # Name the offending molecule, then let the error propagate.
        print(f'Error assembling NMR data for {smiles}')
        raise
    else:
        nmr_data[smiles] = assembled


100%|██████████| 182619/182619 [00:06<00:00, 27428.17it/s]


In [30]:
# Display the assembled entry stored for a single SMILES key.
nmr_data['CN1C2CCC1CC(OC(=O)C(O)c1ccccc1)C2']

{'h_nmr': [2.3499928105861767,
  3.1126957067399488,
  1.6531679627529237,
  1.9031679627529232,
  1.6531679627529237,
  1.9031679627529232,
  3.1126957067399488,
  1.8494863922886802,
  2.09948639228868,
  4.992714033328051,
  5.071375826766051,
  4.184552785568564,
  7.431014308767889,
  7.344257177388522,
  7.302689520964752,
  7.344257177388522,
  7.431014308767889,
  1.8494863922886802,
  2.09948639228868],
 'c_nmr': [39.15912663386418,
  59.17920115876122,
  27.972929656346743,
  27.972929656346743,
  59.17920115876122,
  35.69239515668891,
  68.21024509912226,
  173.3923212900274,
  73.56594809449798,
  138.03578026240166,
  127.19551696608909,
  128.53111673852533,
  128.43539909207226,
  128.53111673852533,
  127.19551696608909,
  35.69239515668891],
 'hsqc': [[39.15912663386418, 2.3499928105861767, 1],
  [59.17920115876122, 3.1126957067399488, 1],
  [27.972929656346743, 1.6531679627529237, -1],
  [27.972929656346743, 1.9031679627529232, -1],
  [27.972929656346743, 1.653167962