In [1]:
import datetime
import os
import pickle
import shutil
import subprocess
import uuid
from glob import glob

import numpy as np
import pandas as pd
from rdkit import Chem


In [2]:

# ──────────────────────────────────────────────────────────────
# Unique folder generator for temporary dataset directories
def unique_dir_name():
    """Return a fresh, human-readable folder name for a temporary KPGT dataset.

    The name starts with the current local timestamp (``DD-MM-YYYY_HH-MM-SS``)
    so folders stay easy to date. It is followed by microseconds and a short
    random hex suffix. Second resolution alone collides when this is called
    twice within one second, and ``os.makedirs`` would then raise
    ``FileExistsError``.

    Returns
    -------
    str
        For example ``'16-04-2025_00-47-00-123456_1a2b3c4d'``.
    """
    stamp = datetime.datetime.now().strftime("%d-%m-%Y_%H-%M-%S-%f")
    return f"{stamp}_{uuid.uuid4().hex[:8]}"

# ──────────────────────────────────────────────────────────────
# KPGT embedding function (replaces RDKit fingerprinting)
def smiles_to_embeddings(
    smiles,
    gpu,
    kpgt_root='/home/malves/predator/KPGT',
    python_bin='/home/malves/miniconda3/envs/KPGT/bin/python',
):
    """Embed SMILES strings with a pretrained KPGT model via KPGT's CLI scripts.

    Steps:

    1. Create a throw-away dataset folder under ``<kpgt_root>/datasets``.
    2. Run ``preprocess_downstream_dataset.py`` and then ``extract_features.py``
       on that folder, using the interpreter ``python_bin``.
    3. Load the ``fps`` array from the produced ``kpgt_base.npz``.

    The temporary folder is removed whether or not a step fails.

    Parameters
    ----------
    smiles : sequence of str
        SMILES strings to embed. Row ``i`` of the result is meant to
        correspond to ``smiles[i]``.
    gpu : int or str
        GPU index forwarded to ``extract_features.py --gpu``.
    kpgt_root : str, optional
        Root of the KPGT checkout, containing ``scripts/``, ``datasets/`` and
        ``models/pretrained/base/base.pth``.
    python_bin : str, optional
        Python interpreter of the environment KPGT is installed in.

    Returns
    -------
    numpy.ndarray
        The ``fps`` array stored in ``kpgt_base.npz``.

    Raises
    ------
    subprocess.CalledProcessError
        If either KPGT script exits with a non-zero status.
    ValueError
        If the number of returned rows differs from ``len(smiles)``. Without
        this check, rows would silently be attached to the wrong SMILES.
    """
    smiles = list(smiles)
    folder = unique_dir_name()
    datasets_root = os.path.join(kpgt_root, 'datasets')
    dataset_path = os.path.join(datasets_root, folder)
    scripts_dir = os.path.join(kpgt_root, 'scripts')
    os.makedirs(dataset_path)

    try:
        # The KPGT scripts read <datasets_root>/<folder>/<folder>.csv.
        # 'Class' is a dummy label because only features are extracted here.
        pd.DataFrame({'Class': [0] * len(smiles), 'smiles': smiles}).to_csv(
            os.path.join(dataset_path, f'{folder}.csv'), index=False
        )

        # Run the scripts from their own directory via cwd= instead of
        # os.chdir, so the notebook's working directory is never changed.
        # check=True makes a failed step raise here instead of surfacing
        # later as a missing .npz file.
        subprocess.run([
            python_bin, os.path.join(scripts_dir, 'preprocess_downstream_dataset.py'),
            '--data_path', datasets_root,
            '--dataset', folder,
        ], cwd=scripts_dir, check=True)
        print('🧠 Extracting features...')
        subprocess.run([
            python_bin, os.path.join(scripts_dir, 'extract_features.py'),
            '--config', 'base',
            '--model_path', os.path.join(kpgt_root, 'models', 'pretrained', 'base', 'base.pth'),
            '--data_path', datasets_root,
            '--gpu', str(gpu),
            '--dataset', folder,
        ], cwd=scripts_dir, check=True)

        # NpzFile holds an open file handle; indexing it reads the array
        # into memory, so the result stays valid after the file is closed.
        with np.load(os.path.join(dataset_path, 'kpgt_base.npz')) as data:
            fps_array = data['fps']
    finally:
        # Best-effort cleanup, which also runs when a step above raised.
        # shutil.rmtree avoids shelling out to `rm -r` with an interpolated
        # path. ignore_errors keeps a cleanup failure from masking the
        # original exception.
        shutil.rmtree(dataset_path, ignore_errors=True)

    if len(fps_array) != len(smiles):
        raise ValueError(
            f'KPGT returned {len(fps_array)} embeddings for {len(smiles)} SMILES; '
            'rows cannot be aligned to the input.'
        )
    return fps_array

# ──────────────────────────────────────────────────────────────
# Main embedding computation for a dataset split
def compute_kpgt_embeddings_for_dataset(csv_paths, output_fp_cache_path, gpu=6):
    """Embed every unique SMILES across a set of CSV files and pickle the lookup.

    Parameters
    ----------
    csv_paths : iterable of str
        CSV files, for example all folds of one split. Each must have a
        ``smiles`` column.
    output_fp_cache_path : str
        Destination file for the pickled ``{smiles: embedding_row}`` dict.
    gpu : int, default 6
        GPU index forwarded to ``smiles_to_embeddings``.

    Raises
    ------
    ValueError
        If ``csv_paths`` is empty, if the CSVs contain no SMILES, or if the
        number of embeddings returned differs from the number of unique SMILES.
    """
    csv_paths = list(csv_paths)
    if not csv_paths:
        # Without this check, a glob that matched nothing would silently run
        # KPGT on an empty dataset.
        raise ValueError("No CSV files given; check the root directory and glob pattern.")

    unique_smiles = set()
    for path in csv_paths:
        # Empty cells load as float NaN. NaN is not a molecule, and mixing it
        # with str makes sorted() raise TypeError.
        unique_smiles.update(pd.read_csv(path)["smiles"].dropna())
    # Sort to get a stable, reproducible row order across runs.
    all_smiles = sorted(unique_smiles)
    if not all_smiles:
        raise ValueError(f"No SMILES found in {len(csv_paths)} CSV file(s).")

    print(f"🧬 Total unique SMILES: {len(all_smiles)}")
    fps_array = smiles_to_embeddings(all_smiles, gpu=gpu)
    # Row i of fps_array belongs to all_smiles[i]. zip() would silently
    # truncate, so refuse to build a misaligned mapping.
    if len(fps_array) != len(all_smiles):
        raise ValueError(
            f"Got {len(fps_array)} embeddings for {len(all_smiles)} unique SMILES."
        )
    smiles_to_fp = dict(zip(all_smiles, fps_array))

    with open(output_fp_cache_path, "wb") as f:
        pickle.dump(smiles_to_fp, f)

    print(f"✅ Embeddings saved to: {output_fp_cache_path}")

In [3]:
# RAW split: build one shared KPGT embedding cache covering every fold CSV
root = "/home/malves/predinhib_mtb/data/cv/raw_h37rv_nr/folds/raw"
csv_paths = glob(f"{root}/raw_*.csv")
print(f"📄 Found {len(csv_paths)} CSV files for RAW")
compute_kpgt_embeddings_for_dataset(
    csv_paths=csv_paths,
    output_fp_cache_path=os.path.join(root, "kpgt_embeddings_cache.pkl"),
    gpu=6,
)

📄 Found 25 CSV files for RAW
🧬 Total unique SMILES: 18780


Using backend: pytorch
[Parallel(n_jobs=32)]: Using backend LokyBackend with 32 concurrent workers.


constructing graphs


Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
[Parallel(n_jobs=32)]: Done 136 tasks      | elapsed:    3.4s
[Parallel(n_jobs=32)]: Done 936 tasks      | elapsed:    5.1s
[Parallel(n_jobs=32)]: Done 2336 tasks      | elapsed:    7.0s
[Parallel(n_jobs=32)]: Done 4136 tasks      | elapsed:    9.5s
[Parallel(n_jo

saving graphs
extracting fingerprints
saving fingerprints
extracting molecular descriptors
🧠 Extracting features.../18780


Using backend: pytorch


The extracted features were saved at /home/malves/predator/KPGT/datasets//16-04-2025_00-47-00/kpgt_base.npz
✅ Embeddings saved to: /home/malves/predinhib_mtb/data/cv/raw_h37rv_nr/folds/raw/kpgt_embeddings_cache.pkl


In [4]:
# H37Rv split: build one shared KPGT embedding cache covering every fold CSV
root = "/home/malves/predinhib_mtb/data/cv/raw_h37rv_nr/folds/h37rv"
csv_paths = glob(f"{root}/h37rv_*.csv")
print(f"📄 Found {len(csv_paths)} CSV files for H37Rv")
compute_kpgt_embeddings_for_dataset(
    csv_paths=csv_paths,
    output_fp_cache_path=os.path.join(root, "kpgt_embeddings_cache.pkl"),
    gpu=6,
)

📄 Found 25 CSV files for H37Rv
🧬 Total unique SMILES: 14187


Using backend: pytorch
[Parallel(n_jobs=32)]: Using backend LokyBackend with 32 concurrent workers.


constructing graphs


Using backend: pytorch
Using backend: pytorch
Using backend: pytorchUsing backend: pytorch

Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorchUsing backend: pytorch

Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
[Parallel(n_jobs=32)]: Done 136 tasks      | elapsed:    3.5s
[Parallel(n_jobs=32)]: Done 744 tasks      | elapsed:    4.8s
[Parallel(n_jobs=32)]: Done 2144 tasks      | elapsed:    6.8s
[Parallel(n_jobs=32)]: Done 3944 tasks      | elapsed:    9.3s
[Parallel(n_jo

saving graphs
extracting fingerprints
saving fingerprints
extracting molecular descriptors
🧠 Extracting features.../14187


Using backend: pytorch


The extracted features were saved at /home/malves/predator/KPGT/datasets//16-04-2025_01-21-09/kpgt_base.npz
✅ Embeddings saved to: /home/malves/predinhib_mtb/data/cv/raw_h37rv_nr/folds/h37rv/kpgt_embeddings_cache.pkl


In [5]:
# NR split: build one shared KPGT embedding cache covering every fold CSV
root = "/home/malves/predinhib_mtb/data/cv/raw_h37rv_nr/folds/nr"
csv_paths = glob(f"{root}/nr_*.csv")
print(f"📄 Found {len(csv_paths)} CSV files for NR")
compute_kpgt_embeddings_for_dataset(
    csv_paths=csv_paths,
    output_fp_cache_path=os.path.join(root, "kpgt_embeddings_cache.pkl"),
    gpu=6,
)


📄 Found 25 CSV files for NR
🧬 Total unique SMILES: 18402


Using backend: pytorch
[Parallel(n_jobs=32)]: Using backend LokyBackend with 32 concurrent workers.


constructing graphs


Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
[Parallel(n_jobs=32)]: Done 162 tasks      | elapsed:    3.3s
[Parallel(n_jobs=32)]: Done 1032 tasks      | elapsed:    4.8s
[Parallel(n_jobs=32)]: Done 2432 tasks      | elapsed:    6.9s
[Parallel(n_jobs=32)]: Done 4232 tasks      | elapsed:    9.8s
[Parallel(n_j

saving graphs
extracting fingerprints
saving fingerprints
extracting molecular descriptors
🧠 Extracting features.../18402


Using backend: pytorch


The extracted features were saved at /home/malves/predator/KPGT/datasets//16-04-2025_01-36-42/kpgt_base.npz
✅ Embeddings saved to: /home/malves/predinhib_mtb/data/cv/raw_h37rv_nr/folds/nr/kpgt_embeddings_cache.pkl
