# Search for suitable datasets

In [1]:
import os
import pandas as pd

interested_assay = 'IC50_in_nM'
least_sample_num = 1600
dataset_dir = '/home1/yueming/Drug_Discovery/Datasets/'

# Layout: <dataset_dir>/<target>/<assay>/<files>. Print every table of the assay
# of interest whose file stem ends in "all" and that has at least
# `least_sample_num` rows.
for target in os.listdir(dataset_dir):
    target_path = os.path.join(dataset_dir, target)
    if interested_assay not in os.listdir(target_path):
        continue
    assay = interested_assay
    assay_path = os.path.join(target_path, assay)
    all_tables = [name for name in os.listdir(assay_path) if name.split('.')[0].endswith('all')]
    for data_file in all_tables:
        df_path = os.path.join(assay_path, data_file)
        df = pd.read_csv(df_path)
        df_len = len(df)
        if df_len >= least_sample_num:
            print(target, assay, df_len)

FileNotFoundError: [Errno 2] No such file or directory: '/home1/yueming/Drug_Discovery/Datasets/'

In [None]:
import os
import pandas as pd

interested_assay = 'IC50_in_nM'
least_sample_num, largest_sample_num = 1600, 2000
dataset_dir = '/home1/yueming/Drug_Discovery/Datasets/'

# Same scan as the previous cell, but keep only "*all" tables whose row count
# lies in [least_sample_num, largest_sample_num].
# - Entries that are not directories (stray files in the dataset tree) are
#   skipped instead of making os.listdir raise NotADirectoryError.
# - Listings are sorted because os.listdir order is arbitrary.
for target in sorted(os.listdir(dataset_dir)):
    assay_path = os.path.join(dataset_dir, target, interested_assay)
    if not os.path.isdir(assay_path):
        continue
    for data_file in sorted(os.listdir(assay_path)):
        # splitext keeps dots inside the stem and ignores e.g. "*_all.csv.bak".
        if not os.path.splitext(data_file)[0].endswith('all'):
            continue
        df_path = os.path.join(assay_path, data_file)
        if not os.path.isfile(df_path):
            continue
        df = pd.read_csv(df_path)
        df_len = len(df)
        if least_sample_num <= df_len <= largest_sample_num:
            print(target, interested_assay, df_len)

# Standardize datasets

In [None]:
import pandas as pd
from tqdm import tqdm
import numpy as np
import math

# Targets to standardize, keyed by ChEMBL target ID. Each value is
# [assay type, #compounds, #binding sites, #pockets, UniProt ID, protein name]
# (all human targets). Only value[0] (the assay type) is read in this cell.
# NOTE(review): CHEMBL2971's last field is 'JAK2_HUMAN' (a UniProt entry name)
# rather than a full protein name like the other entries.
target_assay_dict = {'CHEMBL3820': ['EC50_in_nM', 997, 11, None, 'P35557', 'Hexokinase-4'], 
                     'CHEMBL4422': ['EC50_in_nM', 1693, 2, None, 'O14842', 'Free fatty acid receptor 1'], 
                     'CHEMBL235': ['EC50_in_nM', 3611, 4, None, 'P37231', 'Peroxisome proliferator-activated receptor gamma'],
                     'CHEMBL202': ['IC50_in_nM', 957, 8, 7, 'P00374','Dihydrofolate reductase'], 
                     'CHEMBL3976': ['IC50_in_nM', 1642, 3, None, 'Q9UHL4','Dipeptidyl peptidase 2'], 
                     'CHEMBL333': ['IC50_in_nM', 3686, 24, None, 'P08253','72 kDa type IV collagenase'], 
                     'CHEMBL2971': ['IC50_in_nM', 6207, 8, None, 'O60674','JAK2_HUMAN'], 
                     'CHEMBL279': ['IC50_in_nM', 9573, 4, None, 'P35968','Vascular endothelial growth factor receptor 2']}
# Raw assay folder name -> name of the standardized (-log10 molar) activity.
task_name_convert_dict = {'EC50_in_nM': 'pEC50', 'IC50_in_nM': 'pIC50'}
# Raw per-target activity tables ("<target> - table 2.csv").
row_activity_root = '/home1/yueming/Drug_Discovery/OneDrive_1_2022-12-13/Table 2 - Done/'
# Pre-split datasets: <target>/<assay>/<target>_<assay>_<split>.csv.
query_root = '/home1/yueming/Drug_Discovery/Datasets/'
# Destination of the standardized CSVs.
output_root = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/dataset/'
# 'all' is the full table; the remaining names are the test/train splits.
set_list = ['all', 'test', 'train_1', 'train_2', 'train_3', 'train_4', 'train_5']

def _subset_path(all_path, set_name):
    """Return the path of split `set_name` that sits next to an ``*_all.csv`` file.

    Only the trailing ``_all.csv`` suffix is rewritten, so an ``all`` substring
    elsewhere in the path (e.g. in a directory name) cannot be clobbered.
    Paths without that suffix fall back to a plain ``str.replace``.
    """
    suffix = '_all.csv'
    if all_path.endswith(suffix):
        return all_path[:-len(suffix)] + f'_{set_name}.csv'
    return all_path.replace('all', set_name)


def re_normalize(source_df, query_path, set_name, output_path, output_assay_name):
    """Replace the raw activities of one split with a standardized p-activity and save it.

    For every SMILES of split `set_name`, the ChEMBL compound ID is looked up in
    the ``*_all.csv`` table at `query_path`. All non-missing ``Standard Value``
    entries (nM) of that compound in `source_df` are averaged and converted to
    ``-log10(mean_nM) + 9``, i.e. the activity in -log10 molar units.

    Parameters
    ----------
    source_df : pandas.DataFrame
        Raw activity table with ``ChEMBL Compound ID`` and ``Standard Value`` columns.
    query_path : str
        Path of the ``*_all.csv`` table (``SMILES`` and ``ChEMBL_Compound_ID``
        columns). The split files are expected next to it, named with ``all``
        replaced by the split name.
    set_name : str
        ``'all'`` or a split name. For ``'all'`` the module-level ``set_list[1:]``
        names are used to fill a ``Subset`` column.
    output_path : str
        Where the standardized CSV is written.
    output_assay_name : str
        Name of the standardized activity (e.g. ``'pIC50'``).

    Notes
    -----
    - A SMILES with no matching compound ID gets NaN activity and a message,
      instead of failing on an empty array comparison.
    - When one SMILES maps to several IDs (or duplicated query rows), their
      measurements are pooled and the first ID is recorded.
    """
    new_column_names = {'value': output_assay_name, 'smiles': 'SMILES', 'ChEMBL_Compound_ID': 'ChEMBL_Compound_ID'}
    query_df = pd.read_csv(query_path)
    df = pd.read_csv(_subset_path(query_path, set_name))

    # The 'all' table and the split tables use different column names.
    smi_column_name = 'SMILES' if set_name == 'all' else 'smiles'
    value_column_name = 'Standard_Value' if set_name == 'all' else 'value'
    smiles_list = df[smi_column_name].tolist()

    for smiles in tqdm(smiles_list):
        row_indices = df.index[df[smi_column_name] == smiles]
        compound_ids = query_df.loc[query_df['SMILES'] == smiles, 'ChEMBL_Compound_ID'].unique()
        if len(compound_ids) == 0:
            print(f'No ChEMBL compound ID found for SMILES {smiles}')
            df.loc[row_indices, value_column_name] = np.nan
            continue
        # isin() handles one or several IDs. Comparing against a multi-element
        # ID array would not broadcast.
        single_df = source_df[source_df['ChEMBL Compound ID'].isin(compound_ids)]
        assay_values = single_df['Standard Value'].dropna()
        # The mean of an empty Series is NaN, so the activity becomes NaN.
        assay_value = assay_values.mean()
        standard_assay_value = -np.log10(assay_value) + 9  # 1 M = 10^9 nM
        df.loc[row_indices, value_column_name] = standard_assay_value
        df.loc[row_indices, 'ChEMBL_Compound_ID'] = compound_ids[0]

    if set_name == 'all':
        df['Standard_Type'] = output_assay_name
        del df['Standard_Units']
        # Tag every molecule with the split it belongs to (later splits win on overlap).
        for s in set_list[1:]:
            subset_df = pd.read_csv(_subset_path(query_path, s))
            df.loc[df['SMILES'].isin(subset_df['smiles']), 'Subset'] = s
    else:
        df = df.rename(columns=new_column_names)
    df.to_csv(output_path, index=False)
    
import os  # this cell otherwise relies on `os` having been imported by an earlier cell

# Standardize every split of every target listed in target_assay_dict.
for key, value in tqdm(target_assay_dict.items()):
    # CHEMBL202 is excluded here; it appears to be handled by the CHEMBL202-specific
    # cells further down (TODO confirm). Skipping before reading its source table
    # means that CSV is neither loaded nor required.
    if key == 'CHEMBL202':
        continue
    assay_type = value[0]
    task_name = "_".join([key, assay_type])
    output_assay_name = task_name_convert_dict[assay_type]
    source_df = pd.read_csv(row_activity_root + f'{key} - table 2.csv')
    query_path = query_root + key + f'/{assay_type}/{task_name}' + '_all.csv'
    output_dir = output_root + f'{key}/{output_assay_name}/'
    os.makedirs(output_dir, exist_ok=True)
    for set_name in set_list:
        output_path = output_dir + f'{key}_{output_assay_name}_{set_name}.csv'
        print(f'Processing {key} {output_assay_name} {set_name}')
        re_normalize(source_df, query_path, set_name, output_path, output_assay_name)

# Read and save docking results for QVina-W

In [None]:
import os, re
from tqdm import tqdm
from rdkit import Chem
import pandas as pd

# Root of the QVina-W docking results for CHEMBL202 / receptor 1boz. The loop
# below reads each entry of it as a per-subset sub-directory.
result_root = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/CHEMBL202/1boz/qvina_pocket/'
# The per-subset summary CSVs are written back under the same root.
output_csv_dir = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/CHEMBL202/1boz/qvina_pocket/'
input_sets = os.listdir(result_root)

# Regexes for the top-1 (mode 1) row of a QVina result table only:
# `pattern` captures a decimal affinity, `pattern_` an integer one.
pattern = r'^\s+1\s+(-?\d+\.\d+)'
pattern_ = r'^\s+1\s+(-?\d+)'

def extract_smiles_from_pdbqt(file_path):
    """Return the SMILES from the first ``REMARK SMILES`` line of a PDBQT file.

    The text after the marker is returned stripped of surrounding whitespace.
    Returns None when no line starts with ``REMARK SMILES``.
    """
    marker = 'REMARK SMILES'
    with open(file_path, 'r') as handle:
        remark = next((row for row in handle if row.startswith(marker)), None)
    if remark is None:
        return None
    return remark.split(marker)[1].strip()

# Build one summary table per subset: a row per compound and one Vina score
# column per pocket (the pocket number is the last "_" field of the file name).
for set_name in tqdm(input_sets):
    result_dir = result_root + set_name + '/'
    if not os.path.isdir(result_dir):
        # Skip stray files sitting next to the per-subset directories.
        continue
    df = pd.DataFrame(columns=['ChEMBL_Compound_ID', 'SMILES', 'Vina_Score_1', 'Vina_Score_2', 'Vina_Score_3', 'Vina_Score_4', 
                               'Vina_Score_5', 'Vina_Score_6', 'Vina_Score_7'])
    for input_file in os.listdir(result_dir):
        if input_file[-5:] != 'pdbqt':
            continue
        file_name = input_file[:-6]
        cpd = file_name.split('_')[0]
        pocket = file_name.split('_')[-1]
        # Docking log next to the pose file.
        txt_path = result_dir + file_name + '.txt' # need to be coverted to "input_file.replace('pdbqt', 'txt')" after being corrected
        with open(txt_path, 'r') as txt:
            lines = txt.readlines()
        smiles = extract_smiles_from_pdbqt(result_dir + input_file)
        # Reset for every file, so a log without a top-1 row cannot inherit the
        # previous ligand's score. Try the decimal pattern, then the integer one.
        affinity = None
        for line in lines:
            match = re.match(pattern, line) or re.match(pattern_, line)
            if match:
                affinity = float(match.group(1))
                break
        if affinity is None:
            print(f'No top-1 score found in {txt_path}')
        if cpd in df['ChEMBL_Compound_ID'].values:
            df.loc[df['ChEMBL_Compound_ID'] == cpd, f'Vina_Score_{pocket}'] = affinity
        else:
            df.loc[len(df)] = [cpd, smiles] + [None] * 7
            df.loc[len(df) - 1, f'Vina_Score_{pocket}'] = affinity

    # Save the per-subset summary table.
    df.to_csv(output_csv_dir + f'{set_name}/vina_results.csv', index=False)

# Save Vina scores into the datasets

In [None]:
import os, re
from tqdm import tqdm
from rdkit import Chem
import pandas as pd

# Root of the QVina docking results for CHEMBL202 / receptor 1boz; one
# sub-directory of docked ligands per data subset.
result_dir = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/CHEMBL202/1boz/qvina_pocket/'
set_list = ['test', 'train_1', 'train_2', 'train_3', 'train_4', 'train_5']
# Prefix of the per-subset dataset CSVs (<prefix><subset>.csv). output_file is
# the same prefix, so the updated tables overwrite their inputs in place.
csv_dir = f'{result_dir}/CHEMBL202_pIC50_'
output_file = csv_dir
    
# Regexes for the mode-1 row of a QVina result table (decimal / integer affinity).
pattern = r'^\s+(1)\s+(-?\d+\.\d+)'
pattern_ = r'^\s+(1)\s+(-?\d+)'
# new_column_names = {'value': 'pIC50', 'smiles': 'SMILES', 'ChEMBL_Compound_ID': 'ChEMBL_Compound_ID'}

# Mode-1 row of a result table; the affinity may be decimal or integer
# (the decimal form is preferred when both could match).
top1_regex = re.compile(r'^\s+(1)\s+(-?\d+(?:\.\d+)?)')

# For every subset, add one "Pocket_<n>_Vina_Score" column per docking pocket
# and fill it with each compound's top-1 affinity.
for dataset in tqdm(set_list):
    dataset_path = result_dir + dataset + '/'
    dataset_df = pd.read_csv(csv_dir + dataset + '.csv')
    pdbqt_files = [name for name in os.listdir(dataset_path) if name[-5:] == 'pdbqt']
    for input_file in pdbqt_files:
        file_name = input_file[:-6]
        name_parts = file_name.split('_')
        cpd, pocket = name_parts[0], name_parts[-1]
        with open(dataset_path + file_name + '.txt', 'r') as log_file:
            log_lines = log_file.readlines()

        column_name = f'Pocket_{pocket}_Vina_Score'
        target_rows = dataset_df.index[dataset_df['ChEMBL_Compound_ID'] == cpd]
        for line in log_lines:
            hit = top1_regex.match(line)
            if hit is None:
                continue
            dataset_df.loc[target_rows, column_name] = float(hit.group(2))

    # Write the updated table.
    dataset_df.to_csv(output_file + dataset + '.csv', index=False)


In [None]:
import os
from tqdm import tqdm
from rdkit import Chem

import subprocess

data_root = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/'
task_dict = {0: ('CHEMBL202', 'pIC50'), 1: ('CHEMBL235', 'pEC50'), 2: ('CHEMBL279', 'pIC50'), 3: ('CHEMBL2971', 'pIC50'), 
             4: ('CHEMBL333', 'pIC50'), 5: ('CHEMBL3820', 'pEC50'), 6: ('CHEMBL3976', 'pIC50'), 7: ('CHEMBL4422', 'pEC50')}
set_list = ['test', 'train_1', 'train_2', 'train_3', 'train_4', 'train_5']

# Convert every ligand SDF of every target/subset to PDBQT with mk_prepare_ligand.py.
# The SDFs are converted as they are; no explicit-hydrogen rewrite step is run here.
for key, value in task_dict.items():
    target, assay = value
    input_dir = f'{data_root}/{target}/ligand/sdf/'
    output_dir_pdbqt = f'{data_root}/{target}/ligand/pdbqt/'

    for set_name in tqdm(set_list):
        input_files = os.listdir(input_dir + set_name)
        pdbqt_save_path = output_dir_pdbqt + set_name + '/'
        os.makedirs(pdbqt_save_path, exist_ok=True)
        for input_file in tqdm(input_files):
            # input_file[:-4] strips a 4-character extension, so only .sdf files are handled.
            if not input_file.endswith('.sdf'):
                continue
            input_file_path = input_dir + set_name + '/' + input_file
            output_path = f"{pdbqt_save_path + input_file[:-4]}.pdbqt"
            # Skip ligands already converted, so the cell can be re-run to resume
            # (same behaviour as the PARP1 ligand-preparation cell).
            if os.path.exists(output_path):
                continue
            # Argument list, no shell: paths are passed through verbatim.
            subprocess.run(['mk_prepare_ligand.py', '-i', input_file_path, '-o', output_path])

# Preprocessing PARP1 datasets

## Generate a table from the web-scraped data

In [None]:
# 导入必要的库
import re
import pandas as pd

# Input/output paths.
input_file = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/Web_Scrapper_Data.txt'
output_file = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/Web_Scrapper_Data.csv'
output_txt = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/Web_Scrapper_Data_Rowwise.txt'

# Read the raw scraped text.
with open(input_file, 'r') as file:
    text = file.read()

# Start every record on its own line by inserting a newline before each "COMPOUND:".
text = text.replace("COMPOUND:", "\nCOMPOUND:")

# Keep the one-record-per-line text for inspection.
with open(output_txt, 'w') as file:
    file.write(text)

print("修改后的文本已保存到新文件。")

# One pattern covers both record layouts: the "PUBMED <id>" part is optional.
# Parsing in a single pass keeps rows in source-text order. Two separate passes
# would list every PUBMED record before every record without one.
# Only IC50 values reported in nM are captured.
record_pattern = r'COMPOUND: (\d+) ASSAY TYPE (\w+) IC50 (.*?)nM (?:PUBMED (\d+) )?Source: ([\w]+)'

data = {'COMPOUND': [], 'ASSAY TYPE': [], 'IC50': [], 'PUBMED': [], 'Source': []}
for match in re.finditer(record_pattern, text, flags=re.IGNORECASE):
    compound, assay_type, ic50, pubmed, source = match.groups()
    data['COMPOUND'].append(compound)
    data['ASSAY TYPE'].append(assay_type)
    data['IC50'].append(ic50 + 'nM')
    # Empty string when the record carries no PUBMED id.
    data['PUBMED'].append(pubmed if pubmed is not None else '')
    data['Source'].append(source)

df = pd.DataFrame(data)
# Sanity check: records in the text vs parsed rows. The -1 presumably discounts
# the empty chunk before the first "COMPOUND:".
lines = text.split('\n')
print("文本行数:", len(lines)-1)
print("表格行数:", len(df))

df.to_csv(output_file, index=False)

print("数据已保存为CSV文件。")


## Search and save SMILES according to PUBMED ID

In [None]:
import pandas as pd

# Input: scraped activity table. Output: the same table plus a SMILES column
# looked up in `smiles_csv` by compound ID.
input_csv = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/Web_Scrapper_Data.csv'
output_csv = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/Web_Scrapper_Data_With_SMILES.csv'
smiles_csv = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/parp1_ecfp_integer.csv'

df = pd.read_csv(input_csv)
smiles_df = pd.read_csv(smiles_csv)

# COMPOUND_ID -> SMILES; for a repeated COMPOUND_ID the last occurrence wins.
compound_id_to_smiles = {
    compound_id: smi
    for compound_id, smi in zip(smiles_df['COMPOUND_ID'], smiles_df['SMILES'])
}

# Compounds without an entry get NaN in the new column.
df['SMILES'] = df['COMPOUND'].map(compound_id_to_smiles)

matched_rows = df['SMILES'].notna().sum()
unmatched_rows = len(df) - matched_rows

df.to_csv(output_csv, index=False)

print(f"总共行数: {len(df)}")
print(f"匹配到的行数: {matched_rows}")
print(f"未匹配到的行数: {unmatched_rows}")
print("已添加SMILES列并保存为CSV文件。")


## Split into data subsets according to Bemis-Murcko scaffolds

## Distributions of Bemis-Murcko scaffolds in data subsets

In [None]:
import pandas as pd
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
import random
import matplotlib.pyplot as plt
import numpy as np
import warnings
import re

# NOTE: this silences *all* Python warnings for the rest of the kernel session,
# including e.g. pandas SettingWithCopyWarning raised by later cells.
warnings.filterwarnings('ignore')

# Matplotlib font size for the figures below.
plt.rcParams.update({'font.size': 14})

# Scraped PARP1 IC50 records together with their SMILES.
input_csv = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/Web_Scrapper_Data_With_ALL_SMILES.csv'
df = pd.read_csv(input_csv)

# Parse an IC50 string into a float in nM.
def extract_ic50(ic50_str):
    """Parse an IC50 string such as ``'12.5 nM'`` or ``'>=10nM'`` into a float (nM).

    Any leading ``<``/``>``/``=`` qualifier is ignored; only the number is kept.
    Returns None when the input is not a string (e.g. a NaN cell from pandas)
    or contains no ``<number> nM`` token.
    """
    if not isinstance(ic50_str, str):
        # Missing cells arrive as float NaN, on which re.search would raise TypeError.
        return None
    ic50_match = re.search(r'[<>]?=?\s*(\d+(\.\d+)?)\s*[nN][mM]', ic50_str)
    return float(ic50_match.group(1)) if ic50_match else None

# pIC50 = 9 - log10(IC50 in nM); unparsable or zero IC50 values give None/NaN.
df['pIC50'] = df['IC50'].apply(lambda x: 9 - np.log10(extract_ic50(x)) if extract_ic50(x) else None)

# Median pIC50 of each compound over all its measurements (NaN ignored).
df['Median_pIC50'] = df.groupby('COMPOUND')['pIC50'].transform('median')

# Keep exactly one measurement per compound: the one closest to the compound's median.
# Matching rows by `pIC50 == median` would drop every compound with an even number
# of measurements: the median is then the mean of the two middle values and
# usually equals none of the rows.
# For an odd count the closest row is the median row itself. Ties keep the
# earliest row (stable sort). Rows without a parsable pIC50 are discarded.
distance_to_median = (df['pIC50'] - df['Median_pIC50']).abs()
closest_first = distance_to_median.dropna().sort_values(kind='stable').index
df = (
    df.loc[closest_first]
    .drop_duplicates(subset=['COMPOUND'], keep='first')
    .sort_index()  # restore the original row order
)

# Fixed seed for the scaffold shuffle, so the split is reproducible.
# NOTE(review): the subset CSVs already on disk were presumably produced by an
# earlier unseeded run, so this split need not match them; TODO confirm.
SPLIT_SEED = 0


def _scaffold_smiles(smiles):
    """Canonical SMILES of the Bemis-Murcko scaffold of the molecule `smiles`."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f'Cannot parse SMILES: {smiles}')
    return Chem.MolToSmiles(MurckoScaffold.GetScaffoldForMol(mol))


# Scaffold -> subset number (1-6) and scaffold -> running scaffold index.
scaffold_to_subset, scaffold_to_index = {}, {}

# Compute every molecule's scaffold once; all derived columns below reuse it.
scaffold_series = df['SMILES'].apply(_scaffold_smiles)
unique_scaffolds = list(scaffold_series.unique())

# Spread the unique scaffolds over 6 subsets as evenly as possible; the first
# `remainder` subsets receive one extra scaffold.
num_subsets = 6
scaffolds_per_subset = len(unique_scaffolds) // num_subsets
remainder = len(unique_scaffolds) % num_subsets

# Shuffle the scaffolds (with a seeded generator) before assigning them to subsets.
random.Random(SPLIT_SEED).shuffle(unique_scaffolds)

scaffold_counter = 0
scaffold_index_counter = 0
subset_labels = {
    1: 'Train #1',
    2: 'Train #2',
    3: 'Train #3',
    4: 'Train #4',
    5: 'Train #5',
    6: 'Test'
}

for i in range(1, num_subsets + 1):
    num_scaffolds = scaffolds_per_subset + (1 if i <= remainder else 0)
    subset_scaffolds = unique_scaffolds[scaffold_counter: scaffold_counter + num_scaffolds]
    scaffold_counter += num_scaffolds
    for scaffold in subset_scaffolds:
        scaffold_to_subset[scaffold] = i
        scaffold_to_index[scaffold] = scaffold_index_counter
        scaffold_index_counter += 1

# Every scaffold was assigned above, so these lookups cannot miss.
df['Subset'] = scaffold_series.map(lambda s: subset_labels[scaffold_to_subset[s]])
df['Scaffold_Index'] = scaffold_series.map(scaffold_to_index)
df['Scaffold_SMILES'] = scaffold_series

def _move_test_last(counts):
    """Return `counts` with its 'Test' row moved after the training-subset rows.

    groupby sorts subset labels alphabetically, which puts 'Test' before
    'Train #1'; the bar chart should read Train #1 ... Train #5, Test.
    """
    test_mask = counts['Subset'] == 'Test'
    return pd.concat([counts[~test_mask], counts[test_mask]], ignore_index=True)


# Number of distinct scaffold indices per subset.
subset_scaffold_counts = (
    df.groupby(['Subset'])['Scaffold_Index'].nunique().reset_index(name='Scaffold_Count')
)
# Number of molecules per subset.
subset_total_counts = df.groupby(['Subset']).size().reset_index(name='Total_Count')

fig, ax = plt.subplots(figsize=(10, 6))

subset_total_counts = _move_test_last(subset_total_counts)
subset_scaffold_counts = _move_test_last(subset_scaffold_counts)

# Overlaid bars on a single y-axis: all molecules and unique scaffolds per subset.
bar1 = ax.bar(subset_total_counts['Subset'], subset_total_counts['Total_Count'],
              color='skyblue', alpha=0.7, label='Molecules')
bar2 = ax.bar(subset_scaffold_counts['Subset'], subset_scaffold_counts['Scaffold_Count'],
              color='lightcoral', alpha=1, label='Unique Scaffolds')
ax.set_xlabel('Data Subset')
ax.set_ylabel('Data Number')
ax.set_title('Number of Unique Scaffolds and Total Molecules in Each Data Subset')

# One tick per categorical bar position, labelled Train #1 ... Train #5, Test.
ax.set_xticks(np.arange(0, num_subsets))
ax.set_xticklabels([subset_labels[k] for k in range(1, num_subsets + 1)])

legend = ax.legend(loc='upper left')

plt.grid(axis='y', linestyle='--', alpha=0.7)
# Write the value of every bar onto the chart.
def autolabel(bars):
    """Annotate each bar of `bars` with its height, drawn inside the bar at 5.5/7 of its height."""
    for bar in bars:
        bar_height = bar.get_height()
        anchor = (bar.get_x() + bar.get_width() / 2, 5.5 * bar_height / 7)
        ax.annotate(f'{bar_height}', xy=anchor,
                    xytext=(0, 3), textcoords="offset points",  # nudge 3 points upward
                    ha='center', va='bottom')


for bar_group in (bar1, bar2):
    autolabel(bar_group)
plt.show()

# Write each scaffold subset to its own CSV (train_1..train_5, test).
output_dir = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/'
# Writing is opt-in (off by default), so re-running this cell does not overwrite
# the subset CSVs that later cells read from this directory.
SAVE_SUBSETS = False
for i in range(1, num_subsets + 1):
    # Rows already carry the matching 'Subset' label, so no reassignment is needed.
    subset_df = df[df['Subset'] == subset_labels[i]]

    if i == num_subsets:
        subset_csv = f'{output_dir}test.csv'
    else:
        subset_csv = f'{output_dir}train_{i}.csv'
    if SAVE_SUBSETS:
        subset_df.to_csv(subset_csv, index=False)
    print(f"{subset_labels[i]}: {len(subset_df)}")

# Only claim the files were saved when they actually were.
if SAVE_SUBSETS:
    print("分子子集已保存为CSV")
else:
    print("SAVE_SUBSETS is False: subset CSVs were not written")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import re
import numpy as np
from scipy.stats import gaussian_kde

# Subset CSVs; list position + 1 is used as the subset number (6 = Test).
subset_files = [
    '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/train_1.csv',
    '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/train_2.csv',
    '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/train_3.csv',
    '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/train_4.csv',
    '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/train_5.csv',
    '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/test.csv'
]

# Placeholder frame for the combined subset data.
data = pd.DataFrame(columns=['Subset', 'Scaffold_Index', 'pIC50'])

# Parse an IC50 string into a float in nM (same parser as in the scaffold-split
# cell; it is not called in this cell).
def extract_ic50(ic50_str):
    """Parse an IC50 string such as ``'12.5 nM'`` or ``'>=10nM'`` into a float (nM).

    Any leading ``<``/``>``/``=`` qualifier is ignored; only the number is kept.
    Returns None when the input is not a string (e.g. a NaN cell from pandas)
    or contains no ``<number> nM`` token.
    """
    if not isinstance(ic50_str, str):
        # Missing cells arrive as float NaN, on which re.search would raise TypeError.
        return None
    ic50_match = re.search(r'[<>]?=?\s*(\d+(\.\d+)?)\s*[nN][mM]', ic50_str)
    return float(ic50_match.group(1)) if ic50_match else None

# Load every subset file, tag its rows with the 1-based subset number, and
# combine them with a single concat. Repeatedly concatenating onto an empty
# placeholder frame grows the frame quadratically. Recent pandas (>= 2.1) may
# also emit a FutureWarning, because empty entries will affect the resulting dtypes.
subset_frames = []
for subset_num, subset_file in enumerate(subset_files, start=1):
    subset_df = pd.read_csv(subset_file)
    subset_df['Subset'] = subset_num
    subset_frames.append(subset_df)
data = pd.concat(subset_frames, ignore_index=True)

subset_labels = {
    1: 'Train #1',
    2: 'Train #2',
    3: 'Train #3',
    4: 'Train #4',
    5: 'Train #5',
    6: 'Test'
}
# Scatter of pIC50 against scaffold index for each subset, plus an inset panel
# on the right with one pIC50 density curve per subset.
plt.rcParams.update({'font.size': 16})  # set before any figure text is created
fig, ax1 = plt.subplots(figsize=(13, 6))
colors = ['b', 'g', 'r', 'c', 'm', 'y']
markers = ['o', 's', '^', 'D', 'v', 'p']
# Seeded generator for the horizontal jitter, so the figure is reproducible.
jitter_rng = np.random.default_rng(0)

for subset_num, color, marker in zip(range(1, 7), colors, markers):
    subset_data = data[data['Subset'] == subset_num]
    scaffold_indices = subset_data['Scaffold_Index'].astype(float)
    ic50_values = subset_data['pIC50'].astype(float)
    # Small Gaussian jitter (sd 0.1) separates molecules that share a scaffold index.
    jitter = jitter_rng.standard_normal(len(scaffold_indices)) * 0.1
    ax1.scatter(scaffold_indices + jitter, ic50_values, label=f'{subset_labels[subset_num]}',
                color=color, alpha=0.5, marker=marker)

ax1.set_xlabel('Scaffold Index')
ax1.set_ylabel('pIC50')
ax1.set_title('pIC50 Values for Different Scaffolds in Subsets')
legend = ax1.legend(loc='upper center', ncol=3, fontsize=12)
legend.set_title('', prop={'size': 12})

# Density panel: inset axes just right of the main axes. pIC50 is on its
# y-axis (ticks and label on the right), density on its x-axis.
ax2 = ax1.inset_axes([1.08, 0, 0.2, 1])
ax2.yaxis.tick_right()
ax2.yaxis.set_label_position('right')
ax2.set_ylabel('pIC50')
ax2.set_xlabel('Density')
for subset_num, color in zip(range(1, 7), colors):
    ic50_values = data.loc[data['Subset'] == subset_num, 'pIC50'].astype(float).dropna()
    # gaussian_kde needs at least two distinct values (otherwise the covariance is
    # singular); skip such subsets instead of aborting the whole figure.
    if ic50_values.nunique() < 2:
        continue
    kde = gaussian_kde(ic50_values)
    x_vals = np.linspace(ic50_values.min(), ic50_values.max(), 100)
    y_vals = kde(x_vals)
    # Draw the density horizontally so the curve runs along the pIC50 (y) axis.
    ax2.plot(y_vals, x_vals, color=color)

# Density axis runs left-to-right from 0 to 0.5.
ax2.set_xlim(0, 0.5)
ax2.grid(True)
plt.tight_layout()
plt.show()

## Save SDF files for compounds in data subsets

In [None]:
import pandas as pd
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import AllChem
import os

activity_root = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/'
target_root = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/'
set_list = ['test', 'train_1', 'train_2', 'train_3', 'train_4', 'train_5']

def save_sdf(activity_path, save_path, set_name=None):
    """Embed each compound of an activity CSV in 3D and write ``<save_path>canSAR<COMPOUND>.sdf``.

    Parameters
    ----------
    activity_path : str
        CSV with at least ``SMILES`` and ``COMPOUND`` columns.
    save_path : str
        Output directory, created if missing. It must end with a path separator,
        because file names are appended by string concatenation.
    set_name : str, optional
        Label used in failure messages. Defaults to the last component of
        `save_path` (the subset directory), instead of reading a global variable.

    Compounds whose SDF already exists are skipped, so the function can be
    re-run to resume. Failures (unparsable SMILES, failed embedding or
    optimization, write errors) are reported and skipped.
    """
    label = set_name if set_name is not None else os.path.basename(os.path.normpath(save_path))
    df = pd.read_csv(activity_path)
    os.makedirs(save_path, exist_ok=True)

    for index, smiles in zip(df['COMPOUND'].tolist(), df['SMILES'].tolist()):
        output_file = save_path + f'canSAR{index}.sdf'
        if os.path.exists(output_file):
            continue  # generated by an earlier run
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print('Cannot read SMILES:', smiles)
                raise ValueError('unparsable SMILES')
            hmol = Chem.AddHs(mol)
            # EmbedMolecule returns -1 when no 3D conformer could be generated.
            if AllChem.EmbedMolecule(hmol, AllChem.ETKDG()) != 0:
                raise ValueError('3D embedding failed')
            AllChem.UFFOptimizeMolecule(hmol, 1000)
            # NOTE(review): RDKit treats properties whose names start with "_" as
            # private and SDWriter does not write them, so "_SMILES" is likely not
            # stored in the file. Confirm whether that is intended.
            hmol.SetProp("_SMILES", "%s" % smiles)
            writer = Chem.SDWriter(output_file)
            try:
                writer.write(hmol)
            finally:
                writer.close()  # release the file handle even if writing fails
        except Exception:
            print(f'Fail to save SDF for the compound: {index} in the {label} set')

# SDF files present on disk, tracked per subset. A file saved for one subset
# therefore cannot hide a failure in another subset.
saved_sdf_files = {}

for set_name in tqdm(set_list):
    activity_path = f'{activity_root}/{set_name}.csv'
    save_path = f'{target_root}/ligand/sdf/{set_name}/'
    save_sdf(activity_path, save_path)
    # The directory may not exist if nothing could be saved for this subset.
    saved_sdf_files[set_name] = set(os.listdir(save_path)) if os.path.isdir(save_path) else set()

# Report the COMPOUND ids of each subset for which no SDF file exists.
for set_name in set_list:
    activity_path = f'{activity_root}/{set_name}.csv'
    df = pd.read_csv(activity_path)

    unsaved_compounds = [
        compound for compound in df['COMPOUND'].tolist()
        if f'canSAR{compound}.sdf' not in saved_sdf_files[set_name]
    ]

    print(f"未保存为SDF文件的COMPOUND序号 ({set_name}):")
    print(unsaved_compounds)


## Save SMILES (.smi) lists for compounds in data subsets

In [None]:
import os
import pandas as pd

# Subset CSVs to convert: train_1..train_5 and test, all in the same folder.
activity_dir = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/activity/IC50/'
subset_files = [f'{activity_dir}{name}.csv'
                for name in ['train_1', 'train_2', 'train_3', 'train_4', 'train_5', 'test']]

# Destination folder for the .smi files.
output_dir = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/ligand/smi'
os.makedirs(output_dir, exist_ok=True)

# One "<SMILES> canSAR<COMPOUND>" line per molecule, one .smi file per subset CSV.
for subset_file in subset_files:
    df = pd.read_csv(subset_file)
    df['COMPOUND'] = 'canSAR' + df['COMPOUND'].astype(str)

    smi_name = os.path.basename(subset_file).replace('.csv', '.smi')
    filename = os.path.join(output_dir, smi_name)

    records = [f"{smile} {compound}\n"
               for smile, compound in zip(df['SMILES'].tolist(), df['COMPOUND'].tolist())]
    with open(filename, 'w') as smi_file:
        smi_file.writelines(records)

print("SMI 文件已生成并保存在指定目录下")


## Prepare ligands in pdbqt format for molecular docking

In [None]:
import os
from tqdm import tqdm
from rdkit import Chem

import subprocess

data_root = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/'
set_list = ['test', 'train_1', 'train_2', 'train_3', 'train_4', 'train_5']

input_dir = f'{data_root}/ligand/sdf/'
output_dir_addH = f'{data_root}/ligand/addH/'
output_dir_pdbqt = f'{data_root}/ligand/pdbqt/'

# For every ligand SDF: write an explicit-hydrogen copy and convert the SDF to PDBQT.
for set_name in tqdm(set_list):
    input_files = os.listdir(input_dir + set_name)
    sdf_addH_save_path = output_dir_addH + set_name + '/'
    pdbqt_save_path = output_dir_pdbqt + set_name + '/'
    os.makedirs(sdf_addH_save_path, exist_ok=True)
    os.makedirs(pdbqt_save_path, exist_ok=True)
    for input_file in tqdm(input_files):
        input_file_path = input_dir + set_name + '/' + input_file
        # removeHs=False keeps hydrogens already present in the SDF (save_sdf writes
        # them with 3D coordinates). The default (True) would strip them, and
        # AddHs would then re-add them without coordinates.
        mol = Chem.SDMolSupplier(input_file_path, removeHs=False)[0]
        if mol is None:
            print(f'Cannot read {input_file_path}; skipped')
            continue

        # addCoords=True gives any hydrogens still to be added real 3D positions.
        mol = Chem.AddHs(mol, addCoords=True)
        writer = Chem.SDWriter(sdf_addH_save_path + input_file)
        try:
            writer.write(mol)
        finally:
            writer.close()  # flush and release the file explicitly

        # Convert to PDBQT unless this ligand was already converted.
        output_path = pdbqt_save_path + input_file[:-4] + '.pdbqt'
        if not os.path.exists(output_path):
            subprocess.run(['mk_prepare_ligand.py', '-i', input_file_path, '-o', output_path])

## Run AutoDock Vina on a batch of ligands against one protein

In [None]:
import subprocess
import pandas as pd
from tqdm import tqdm
import os

protein_name, pdb_name, set_index, pocket_index = 'PARP1', '6nrh', 5, 4
# Input/output locations.
input_protein = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/protein/{pdb_name}_protein.pdbqt'
input_dir_ligand = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/ligand/pdbqt/'
input_config_dir = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/protein/'
output_dir = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/{pdb_name}/vina/docking_data/'
# One subset and one docking box per run, selected by set_index / pocket_index.
input_sets = [['test', 'train_1', 'train_2', 'train_3', 'train_4', 'train_5'][set_index]]
pocket_sets = [['pocket_1', 'pocket_2', 'pocket_3', 'protein_center', 'pocket_active'][pocket_index]]

for set_name in tqdm(input_sets):
    for pocket_name in pocket_sets:
        input_config = input_config_dir + f'vina_{pocket_name}.txt'
        input_ligands = os.listdir(input_dir_ligand + set_name)
        for input_ligand in tqdm(input_ligands):
            filename = input_ligand.split('.')[0]
            output_path = f'{output_dir}{set_name}/{pocket_name}/{filename}.pdbqt'
            # Vina's stdout is stored next to the pose file as <filename>.txt. Only
            # the extension is swapped, so a "pdbqt" substring elsewhere in the path
            # (e.g. a directory name) is left intact.
            log_path = os.path.splitext(output_path)[0] + '.txt'
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            # Rename logs saved under an older naming scheme (<filename>.pdbqt.txt).
            if os.path.exists(output_path + '.txt'):
                os.rename(output_path + '.txt', log_path)
            if os.path.exists(output_path):
                continue  # already docked
            # Argument list, no shell.
            command = [
                'vina',
                '--receptor', input_protein,
                '--ligand', input_dir_ligand + set_name + '/' + input_ligand,
                '--config', input_config,
                '--exhaustiveness', '32',
                '--out', output_path,
            ]
            output = subprocess.run(command, capture_output=True, text=True)

            with open(log_path, 'w') as f:
                f.write(output.stdout)
            # Surface failures instead of silently dropping stderr.
            if output.returncode != 0:
                print(f'vina exited with code {output.returncode} for {input_ligand} '
                      f'({set_name}, {pocket_name}): {output.stderr.strip()}')


## Run QVina-W on a batch of ligands against one protein

In [None]:
import os
from tqdm import tqdm

import subprocess
import tempfile

protein_name, pdb_name, set_index = 'PARP1', '6nrh', 0
receptor_file = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/protein/{pdb_name}_protein.pdbqt'
ligand_pdbqt_dir = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/ligand/pdbqt/'
config_file = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/protein/qvina_w.txt'
output_root = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/{pdb_name}/qvina_w/docking_data/'
qvina_w_path = f'/home1/yueming/Drug_Discovery/Baselines/qvina'
input_sets = [['test', 'train_1', 'train_2', 'train_3', 'train_4', 'train_5'][set_index]]

# Read the config template once. Its first four lines (receptor/ligand/out/log)
# are replaced per ligand; the remaining lines are reused unchanged.
with open(config_file, 'r') as f:
    config_template = f.readlines()

for set_name in tqdm(input_sets):
    output_dir = output_root + set_name + '/'
    os.makedirs(output_dir, exist_ok=True)
    input_files = os.listdir(ligand_pdbqt_dir + set_name)
    for input_file in tqdm(input_files):
        if os.path.exists(f'{output_dir}/{input_file}'):
            continue  # already docked
        stem = os.path.splitext(input_file)[0]
        config_lines = list(config_template)
        config_lines[0] = f'receptor = {receptor_file}\n'
        config_lines[1] = f'ligand = {ligand_pdbqt_dir + set_name}/{input_file}\n'
        config_lines[2] = f'out  = {output_dir}/{input_file}\n'
        # The log must not share the pose file's path, or one overwrites the other.
        # "<stem>.txt" matches the log naming read by the QVina result-parsing cells.
        config_lines[3] = f'log  = {output_dir}/{stem}.txt\n'

        # Use a per-ligand temporary config instead of rewriting the shared template
        # in place. Concurrent runs (other set_index values) then cannot clobber
        # each other's settings.
        with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as tmp:
            tmp.writelines(config_lines)
            run_config = tmp.name
        try:
            subprocess.run([f'{qvina_w_path}/qvina-w_serial', '--config', run_config])
        finally:
            os.remove(run_config)


## Run DiffDock on a batch of ligands against one protein

In [None]:
import os
import subprocess
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

protein_name, pdb_name, set_index = 'PARP1', '6nrh', 0
receptor_file = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/protein/{pdb_name}_protein.pdb'
ligand_sdf_dir = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/ligand/sdf/'
output_root = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/{pdb_name}/diffdock/docking_data/'
diffdock_path = f'/home1/yueming/Drug_Discovery/Baselines/DiffDock-main'
input_sets = [['test', 'train_1', 'train_2', 'train_3', 'train_4', 'train_5'][set_index]]

# Run DiffDock once per ligand of the selected subset(s); it writes one folder per ligand.
for set_name in tqdm(input_sets):
    output_dir = output_root + set_name + '/'
    os.makedirs(output_dir, exist_ok=True)
    input_files = os.listdir(ligand_sdf_dir + set_name)
    for input_file in tqdm(input_files):
        filename = input_file.split('.')[0]
        if os.path.exists(f'{output_dir}/{filename}/rank1.sdf'):
            continue  # top-ranked pose already exists for this ligand
        input_file_path = ligand_sdf_dir + set_name + '/' + input_file
        command = f'python {diffdock_path}/inference.py --complex_name {filename} --protein_path {receptor_file} --ligand {input_file_path} --out_dir {output_dir} --inference_steps 20 --samples_per_complex 10 --batch_size 10 --actual_steps 18 --no_final_step_noise'
        # inference.py is run with the DiffDock repo (the folder containing it) as working
        # directory. Passing `cwd` only affects the child process; os.chdir would silently change
        # the working directory of the whole kernel for every later cell.
        subprocess.run(command, shell=True, cwd=diffdock_path)


## Do GNINA on a batch of ligands and one protein

In [None]:
# cd /home1/yueming/Drug_Discovery/Baselines/gnina
# sudo docker run -v /home1:$(pwd)/home1 -it gnina/gnina
# /home1/yueming/Drug_Discovery/Baselines/gnina/home1/yueming/anaconda3/bin/python /home1/yueming/Drug_Discovery/Baselines/gnina/home1/yueming/Drug_Discovery/Baselines/gnina/run_gnina.py
# sudo chown -R yueming:yueming /home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/6nrh/gnina
# rm -r /home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/PARP1/6nrh/gnina/test

import os
import subprocess
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

protein_name, pdb_name, set_index = 'PARP1', '6nrh', 0
gnina_path = f'/home1/yueming/Drug_Discovery/Baselines/gnina'
# Inside the container (see the docker command above) host /home1 is mounted under
# `gnina_path`, hence the prefix on every data path.
data_root = gnina_path + f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}'
receptor_file = gnina_path + f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/protein/{pdb_name}_protein.pdb'
# NOTE: despite its name this directory holds the SDF ligands.
ligand_pdbqt_dir = gnina_path + f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/ligand/sdf/'
output_root = gnina_path + f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/{pdb_name}/gnina/docking_data/'
input_sets = [['test', 'train_1', 'train_2', 'train_3', 'train_4', 'train_5'][set_index]]

# Docking search box. The centre is the same as 'pocket_active' in the KarmaDock cell.
box_center = (22.945, -10.253, 14.962)
box_size = (18.0, 27.0, 16.5)
box_args = (f'--center_x {box_center[0]} --center_y {box_center[1]} --center_z {box_center[2]} '
            f'--size_x {box_size[0]} --size_y {box_size[1]} --size_z {box_size[2]}')

# Dock every ligand of the selected subset(s) that has no output yet.
for set_name in tqdm(input_sets):
    output_dir = output_root + set_name + '/'
    os.makedirs(output_dir, exist_ok=True)
    input_files = os.listdir(ligand_pdbqt_dir + set_name)
    for input_file in tqdm(input_files):
        filename = input_file.split('.')[0]
        out_path = f'{output_dir}{filename}.pdbqt'
        if os.path.exists(out_path):
            continue  # already docked
        input_file_path = ligand_pdbqt_dir + set_name + '/' + input_file
        command = f'gnina --out {out_path} --receptor {receptor_file} --ligand {input_file_path} {box_args} --device 0'
        # Run from `gnina_path` as before, but via `cwd` so the kernel's own working directory
        # is left untouched for later cells.
        subprocess.run(command, shell=True, cwd=gnina_path)
            

## Do KarmaDock on a batch of ligands and one protein

In [None]:
import os
import subprocess
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

protein_name, pdb_name = 'PARP1', '6nrh'
receptor_file = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/protein/{pdb_name}_protein.pdb'
ligand_smi_dir = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/ligand/smi/'
output_root = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/{pdb_name}/karmadock/docking_data/'
karmadock_path = '/home1/yueming/Drug_Discovery/Baselines/KarmaDock-main/'
input_sets = ['test', 'train_1', 'train_2', 'train_3', 'train_4', 'train_5']
pocket_center_dict = {'pocket_active': [22.945, -10.253, 14.962], 'pocket_1': [26.375,-8.04,9.11], 'pocket_2': [38.309,4.975,18.243], 'pocket_3': [29.747,-8.243,32.118], 'protein_center': [25.61,-4.054,18.329]}
# `virtual_screening.py` and the relative `../trained_models/...` model path are resolved from
# KarmaDock's utils/ folder, which is used as the child process's working directory below.
karmadock_utils_dir = karmadock_path + 'utils'

# Screen each subset's SMILES file against every pocket centre.
for set_name in tqdm(input_sets):
    output_dir = output_root + set_name + '/'
    input_file = ligand_smi_dir + f'{set_name}.smi'
    for key, value in pocket_center_dict.items():
        output_path = output_dir + key
        if os.path.exists(f'{output_path}/1'):
            continue  # output sub-folder `1` exists: treated as already screened
        os.makedirs(output_path, exist_ok=True)
        # NOTE(review): the pocket centre list is passed via --crystal_ligand_file; presumably a
        # locally modified KarmaDock accepts coordinates there — TODO confirm.
        command = f'python -u virtual_screening.py --ligand_smi {input_file} --protein_file {receptor_file} --crystal_ligand_file "{value}" --model_file ../trained_models/karmadock_screening.pkl --out_dir {output_path} --batch_size 64 --random_seed 2023'
        print(command)
        # `cwd` instead of os.chdir keeps the kernel's working directory unchanged for later cells.
        subprocess.run(command, shell=True, cwd=karmadock_utils_dir)


## Preprocessing docking results to mols

In [None]:
# conda activate rdkit
import os, re
import pickle
from rdkit import Chem
import pandas as pd
from tqdm import tqdm
import pymol
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

def generate_pocket(dataset_path_template, dock_software_list, input_ligand_format_dict, protein_path, distance=5, save_states=1, generated_data_folder_name='pocket_complex', reverse=False):
    """Save the protein residues surrounding every docked ligand pose as a pocket PDB file.

    For each docking program the poses are read from
    ``dataset_path_template.format(dock_software, 'docking_data')`` and pockets are written under
    ``dataset_path_template.format(dock_software, generated_data_folder_name)``. The layout of the
    docking output differs per program; each branch below documents the one it expects.

    Args:
        dataset_path_template: path with two ``{}`` slots, filled with the docking program and the
            data folder name.
        dock_software_list: docking programs to process.
        input_ligand_format_dict: docking program -> file extension of its ligand pose files.
        protein_path: receptor structure loaded into PyMOL together with each pose.
        distance: cutoff (Å) of the PyMOL ``byres ... within`` pocket selection.
        save_states: number of poses per ligand to cut a pocket for.
        generated_data_folder_name: folder name used for the generated pockets.
        reverse: iterate the docking output entries in reverse sorted order.

    Existing pocket files are skipped unless the notebook-global ``recreate`` is truthy.
    """
    # The pocket selection refers to the protein by its PyMOL object name, i.e. the file stem.
    protein_name = protein_path.split('/')[-1].split('.')[0]
    for dock_software in dock_software_list:
        print(f'Generating pockets for {dock_software} results')
        input_ligand_format = input_ligand_format_dict[dock_software]
        data_dir = dataset_path_template.format(dock_software, 'docking_data')
        save_dir = dataset_path_template.format(dock_software, generated_data_folder_name)
        os.makedirs(save_dir, exist_ok=True)
        complex_id = os.listdir(data_dir)
        complex_id.sort(reverse=reverse)
        for file in complex_id:
            # Parallel lists: pose file to read -> pocket PDB to write for that pose.
            lig_path_list, pocket_save_path_list = [], []
            if os.path.isdir(os.path.join(data_dir, file)):
                cid = file
                if dock_software == 'diffdock':
                    # <cid>/ holds one file per ranked pose (name contains `rank<k>_`); keep k <= save_states.
                    ranking_files = os.listdir(os.path.join(data_dir, cid))
                    pattern = r'rank(\d+)_'
                    for ranking_file in ranking_files:
                        match = re.search(pattern, ranking_file)
                        if match:
                            rank = int(match.group(1))
                            if rank <= save_states:
                                lig_path_list.append(os.path.join(data_dir, cid, ranking_file))
                                pocket_save_path_list.append(os.path.join(save_dir, cid, f'{cid}_{dock_software}_rank{rank}_pocket_{distance}A.pdb'))
                elif dock_software == 'tankbind': # just one state
                    # <cid>/ holds one pose file per p2rank pocket.
                    p2rank_pocket_file_list = os.listdir(os.path.join(data_dir, cid))
                    p2rank_pocket_name_list = [p2rank_pocket_file.split('.')[0] for p2rank_pocket_file in p2rank_pocket_file_list]
                    lig_path_list = [os.path.join(data_dir, cid, p2rank_pocket_file) for p2rank_pocket_file in p2rank_pocket_file_list]
                    pocket_save_path_list = [os.path.join(save_dir, p2rank_pocket_name, cid, f'{cid}_{dock_software}_{p2rank_pocket_name}_pocket_{distance}A.pdb') for p2rank_pocket_name in p2rank_pocket_name_list]
                elif dock_software == 'karmadock':
                    # Here the folder is a pocket name holding `<cid>_<suffix>.<fmt>` pose files.
                    pocket_name = file
                    file_path_list = os.listdir(os.path.join(data_dir, pocket_name))
                    cid_list = [_.split('_')[0] for _ in file_path_list if _.split('.')[-1]==input_ligand_format]
                    suffix_list = [_.split('_', 1)[-1].split('.')[0] for _ in file_path_list if _.split('.')[-1]==input_ligand_format]
                    lig_path_list = [os.path.join(data_dir, pocket_name, _) for _ in file_path_list if _.split('.')[-1]==input_ligand_format]
                    pocket_save_path_list = []
                    for cid, suffix in zip(cid_list, suffix_list):
                        pocket_save_path_list.append(os.path.join(save_dir, pocket_name, cid, f'{cid}_{dock_software}_{pocket_name}_{suffix}_pocket_{distance}A.pdb'))
                elif dock_software == 'vina':
                    # Here the folder is a pocket name holding one multi-pose `<cid>.<fmt>` file per ligand.
                    # NOTE(review): poses are numbered from 0 here but from 1 in the flat-file branch, and
                    # `pose0` becomes PyMOL source_state 0 below. generate_complex builds the same names,
                    # so any renumbering must change both functions together — TODO confirm intent.
                    pocket_name = file
                    file_path_list = os.listdir(os.path.join(data_dir, pocket_name))
                    for ligand_result_file in file_path_list:
                        file_suffix = ligand_result_file.split('.')[-1]
                        if file_suffix == input_ligand_format:
                            cid = ligand_result_file.split('.')[0]
                            for save_state in range(save_states):
                                lig_path_list.append(os.path.join(data_dir, pocket_name, ligand_result_file))
                                pocket_save_path_list.append(os.path.join(save_dir, pocket_name, cid, f'{cid}_{dock_software}_{pocket_name}_pose{save_state}_pocket_{distance}A.pdb'))
            else:
                # Flat layout: multi-pose `<cid>.<fmt>` files directly in data_dir; poses numbered from 1.
                if file.split('.')[-1] == input_ligand_format:
                    cid = file.split('.')[0]
                    lig_path_list = [os.path.join(data_dir, file)] * save_states
                    pocket_save_path_list = [os.path.join(save_dir, cid, f'{cid}_{dock_software}_pose{save_state+1}_pocket_{distance}A.pdb') for save_state in range(save_states)]

            for lig_path, pocket_path in zip(lig_path_list, pocket_save_path_list):
                if not os.path.exists(pocket_path) or recreate:
                    pymol.cmd.delete('all')
                    pymol.cmd.load(lig_path)
                    pymol.cmd.remove('hydrogens')
                    pymol.cmd.load(protein_path)
                    pymol.cmd.remove('resn HOH')
                    object_list = pymol.cmd.get_object_list()  # presumably [ligand, protein] in load order
                    # Both objects are needed; a pose file without molecules yields no ligand object.
                    if not object_list or len(object_list) < 2:
                        print(f'No docking data found for {lig_path}') # if no molecule in file
                        continue
                    obj_ligand = object_list[0]
                    # Pose number encoded in the pocket file name selects the ligand state (default 1).
                    match = re.search(r'_pose(\d+)_', pocket_path)
                    state = int(match.group(1)) if match else 1

                    # Copy that state into a single-state object (target_state=1) and select the
                    # whole residues of the protein within `distance` Å of it.
                    pymol.cmd.create(f"state_{state}", obj_ligand, source_state=state, target_state=1)
                    pymol.cmd.select('Pocket', f'byres {protein_name} within {distance} of state_{state}')
                    os.makedirs(os.path.dirname(pocket_path), exist_ok=True)
                    pymol.cmd.save(pocket_path, 'Pocket')


def split_ligand_to_pdb(ligand_input_path, lig_save_path, save_state=1, generated_data_folder_name='pocket_complex'):
    """Save one state (pose) of a possibly multi-state ligand file as its own PDB file.

    The PDB is written to ``lig_save_path`` with its extension replaced by
    ``_pose{save_state}.pdb`` and ``docking_data`` in the path replaced by
    ``generated_data_folder_name``. An existing file is kept unless the notebook-global
    ``recreate`` is truthy.

    Args:
        ligand_input_path: docking output loaded into PyMOL (hydrogens are removed).
        lig_save_path: path the output name is derived from (see above).
        save_state: PyMOL state to extract; must not exceed the number of states in the file.
        generated_data_folder_name: folder name substituted for ``docking_data``.

    Returns:
        0 if PyMOL loaded no object from ``ligand_input_path``; otherwise None. A message is
        printed when ``save_state`` exceeds the available states.
    """
    pymol.cmd.delete('all')
    pymol.cmd.load(ligand_input_path)
    pymol.cmd.remove('hydrogens')
    object_list = pymol.cmd.get_object_list()  # all objects currently loaded in PyMOL
    if not object_list:
        print(f'No molecule found in: {ligand_input_path}')
        return 0
    obj_ligand = object_list[0]
    total_states = pymol.cmd.count_states(obj_ligand)
    if save_state > total_states:
        print(f'save_state {save_state} is greater than the total_states {total_states}')
        return
    save_path = lig_save_path.rsplit('.', 1)[0] + f'_pose{save_state}.pdb'
    save_path = save_path.replace('docking_data', generated_data_folder_name)
    if not os.path.exists(save_path) or recreate:
        # Copy the requested state into a new single-state object and save only that object.
        pymol.cmd.create(f"state_{save_state}", obj_ligand, save_state, 1)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        pymol.cmd.save(save_path, f"state_{save_state}")
                    
                
def generate_complex(dataset_path_template, dock_software_list, input_ligand_format_dict, distance=5, save_states=1, generated_data_folder_name='pocket_complex', reverse=False):
    """Pair each docked ligand pose with its pocket and pickle the pair as ``(ligand, pocket)``.

    Walks the same docking-output layouts as ``generate_pocket`` and expects the pocket PDB files
    it writes. For each pose the ligand state is first extracted to ``..._pose{k}.pdb`` via
    ``split_ligand_to_pdb``. Ligand and pocket PDBs are then read with RDKit (hydrogens removed)
    and pickled to ``..._complex_{distance}A.rdkit`` under
    ``dataset_path_template.format(dock_software, generated_data_folder_name)``.
    Poses whose ligand or pocket cannot be read are reported and skipped. Existing complex files
    are kept unless the notebook-global ``recreate`` is truthy.

    Args:
        dataset_path_template: path with two ``{}`` slots, filled with the docking program and the
            data folder name.
        dock_software_list: docking programs to process.
        input_ligand_format_dict: docking program -> file extension of its ligand pose files.
        distance: pocket cutoff (Å); only used to build the file names.
        save_states: number of poses per ligand to pair.
        generated_data_folder_name: folder holding pockets and receiving the complexes.
        reverse: iterate the docking output entries in reverse sorted order.
    """
    for dock_software in dock_software_list:
        print(f'Generating complexes for {dock_software} results')
        input_ligand_format = input_ligand_format_dict[dock_software]
        data_dir = dataset_path_template.format(dock_software, 'docking_data')
        save_dir = dataset_path_template.format(dock_software, generated_data_folder_name)
        os.makedirs(save_dir, exist_ok=True)
        file_list = os.listdir(data_dir)
        file_list.sort(reverse=reverse)
        for file in file_list:
            # Parallel lists: pose file, where its extracted pose goes, its pocket, its complex.
            lig_path_list, lig_save_path_list, pocket_path_list, complex_save_path_list = [], [], [], []
            if os.path.isdir(os.path.join(data_dir, file)):
                cid = file
                if dock_software == 'diffdock':
                    # <cid>/ holds one file per ranked pose (name contains `rank<k>_`); keep k <= save_states.
                    ranking_files = os.listdir(os.path.join(data_dir, cid))
                    pattern = r'rank(\d+)_'
                    for ranking_file in ranking_files:
                        match = re.search(pattern, ranking_file)
                        if match:
                            rank = int(match.group(1))
                            if rank <= save_states:
                                lig_path_list.append(os.path.join(data_dir, cid, ranking_file))
                                pocket_path_list.append(os.path.join(save_dir, cid, f'{cid}_{dock_software}_rank{rank}_pocket_{distance}A.pdb'))
                                complex_save_path_list.append(os.path.join(save_dir, cid, f'{cid}_{dock_software}_rank{rank}_complex_{distance}A.rdkit'))
                    lig_save_path_list = lig_path_list
                elif dock_software == 'tankbind': # just one state
                    # <cid>/ holds one pose file per p2rank pocket.
                    p2rank_pocket_file_list = os.listdir(os.path.join(data_dir, cid))
                    p2rank_pocket_name_list = [p2rank_pocket_file.split('.')[0] for p2rank_pocket_file in p2rank_pocket_file_list]
                    lig_path_list = [os.path.join(data_dir, cid, p2rank_pocket_file) for p2rank_pocket_file in p2rank_pocket_file_list]
                    lig_save_path_list = [os.path.join(save_dir, p2rank_pocket_name, cid, p2rank_pocket_name) for p2rank_pocket_name in p2rank_pocket_name_list]
                    pocket_path_list = [os.path.join(save_dir, p2rank_pocket_name, cid, f'{cid}_{dock_software}_{p2rank_pocket_name}_pocket_{distance}A.pdb') for p2rank_pocket_name in p2rank_pocket_name_list]
                    complex_save_path_list = [os.path.join(save_dir, p2rank_pocket_name, cid, f'{cid}_{dock_software}_{p2rank_pocket_name}_complex_{distance}A.rdkit') for p2rank_pocket_name in p2rank_pocket_name_list]
                elif dock_software == 'karmadock':
                    # Here the folder is a pocket name holding `<cid>_<suffix>.<fmt>` pose files.
                    pocket_name = file
                    file_path_list = os.listdir(os.path.join(data_dir, pocket_name))
                    cid_list = [_.split('_')[0] for _ in file_path_list if _.split('.')[-1]==input_ligand_format]
                    suffix_list = [_.split('_', 1)[-1].split('.')[0] for _ in file_path_list if _.split('.')[-1]==input_ligand_format]
                    lig_path_list = [os.path.join(data_dir, pocket_name, _) for _ in file_path_list if _.split('.')[-1]==input_ligand_format]
                    lig_save_path_list = [os.path.join(save_dir, pocket_name, _.split('_')[0], _) for _ in file_path_list if _.split('.')[-1]==input_ligand_format]
                    for cid, suffix in zip(cid_list, suffix_list):
                        pocket_path_list.append(os.path.join(save_dir, pocket_name, cid, f'{cid}_{dock_software}_{pocket_name}_{suffix}_pocket_{distance}A.pdb'))
                        complex_save_path_list.append(os.path.join(save_dir, pocket_name, cid, f'{cid}_{dock_software}_{pocket_name}_{suffix}_complex_{distance}A.rdkit'))
                elif dock_software == 'vina':
                    # Here the folder is a pocket name holding one multi-pose `<cid>.<fmt>` file per ligand.
                    # NOTE(review): poses are numbered from 0 here but from 1 in the flat-file branch;
                    # the names must stay in sync with generate_pocket — TODO confirm intent.
                    pocket_name = file
                    file_path_list = os.listdir(os.path.join(data_dir, pocket_name))
                    for ligand_result_file in file_path_list:
                        file_suffix = ligand_result_file.split('.')[-1]
                        if file_suffix == input_ligand_format:
                            cid = ligand_result_file.split('.')[0]
                            for save_state in range(save_states):
                                lig_path_list.append(os.path.join(data_dir, pocket_name, ligand_result_file))
                                lig_save_path_list.append(os.path.join(save_dir, pocket_name, cid, f'{cid}_{dock_software}_{pocket_name}_pose{save_state}.{input_ligand_format}'))
                                pocket_path_list.append(os.path.join(save_dir, pocket_name, cid, f'{cid}_{dock_software}_{pocket_name}_pose{save_state}_pocket_{distance}A.pdb'))
                                complex_save_path_list.append(os.path.join(save_dir, pocket_name, cid, f'{cid}_{dock_software}_{pocket_name}_pose{save_state}_complex_{distance}A.rdkit'))
            else:
                # Flat layout: multi-pose `<cid>.<fmt>` files directly in data_dir; poses numbered from 1.
                if file.split('.')[-1] == input_ligand_format:
                    cid = file.split('.')[0]
                    lig_path_list = [os.path.join(data_dir, file)] * save_states
                    lig_save_path_list = [os.path.join(save_dir, cid, file)] * save_states
                    pocket_path_list = [os.path.join(save_dir, cid, f'{cid}_{dock_software}_pose{save_state+1}_pocket_{distance}A.pdb') for save_state in range(save_states)]
                    complex_save_path_list = [os.path.join(save_dir, cid, f'{cid}_{dock_software}_pose{save_state+1}_complex_{distance}A.rdkit') for save_state in range(save_states)]

            for ligand_input_path, lig_save_path, pocket_path, complex_save_path in zip(lig_path_list, lig_save_path_list, pocket_path_list, complex_save_path_list):
                if not os.path.exists(complex_save_path) or recreate:
                    # Pose number encoded in the pocket file name (default 1).
                    match = re.search(r'_pose(\d+)_', pocket_path)
                    pose = int(match.group(1)) if match else 1
                    split_ligand_to_pdb(ligand_input_path, lig_save_path, save_state=pose, generated_data_folder_name=generated_data_folder_name)

                    # Same path split_ligand_to_pdb writes the extracted pose to.
                    ligand_path = lig_save_path.rsplit('.', 1)[0] + f'_pose{pose}.pdb'
                    ligand_path = ligand_path.replace('docking_data', generated_data_folder_name)
                    if not os.path.exists(ligand_path):
                        print(f'Not found pose: {ligand_path}')
                        continue

                    ligand_mol = Chem.MolFromPDBFile(ligand_path, removeHs=True)
                    if ligand_mol is None:
                        print(f"Unable to process ligand: {ligand_path}")
                        continue

                    pocket_mol = Chem.MolFromPDBFile(pocket_path, removeHs=True)
                    if pocket_mol is None:
                        print(f"Unable to process pocket: {pocket_path}")
                        continue

                    complex_pair = (ligand_mol, pocket_mol)
                    os.makedirs(os.path.dirname(complex_save_path), exist_ok=True)
                    with open(complex_save_path, 'wb') as f:
                        pickle.dump(complex_pair, f)
    
    
# ---- Run configuration -------------------------------------------------------------------
generated_data_folder_name = 'pocket_complex'
recreate = False   # overwrite existing pocket / pose / complex files
reverse = False    # walk the docking output in reverse sorted order
protein_name, assay_type, pdb_name = 'PARP1', 'IC50', '6nrh'
pose_num = 9       # poses kept per ligand
distance = 5       # pocket cutoff in Å
# qvina_w has to be processed one subset at a time to avoid corrupted output.
dock_software_options = ['karmadock', 'diffdock', 'tankbind', 'gnina', 'qvina_w', 'vina']
dock_software_list = [dock_software_options[5]]
input_ligand_format_dict = {'karmadock': 'sdf', 'diffdock': 'sdf', 'tankbind': 'sdf', 'gnina': 'pdbqt', 'qvina_w': 'pdbqt', 'vina': 'pdbqt'}
input_protein_format_dict = {'karmadock': 'pdb', 'diffdock': 'pdb', 'tankbind': 'pdb', 'gnina': 'pdb', 'qvina_w': 'pdbqt', 'vina': 'pdbqt'}

_gign_data_dir = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data'
protein_path = f'{_gign_data_dir}/{protein_name}/protein/{pdb_name}_protein.pdbqt'
data_root = f'{_gign_data_dir}/{protein_name}/{pdb_name}/'
activity_root = f'{_gign_data_dir}/{protein_name}/activity'
set_list = ['test', 'train_1', 'train_2', 'train_3', 'train_4', 'train_5']

# ---- Pockets, then (ligand, pocket) complexes, for every data subset ----------------------
for dataset in tqdm(set_list):
    print(f'Processing {dataset} ...')
    # Activity table of this subset (not consumed by the two calls below).
    data_df = pd.read_csv(os.path.join(activity_root, assay_type, f'{dataset}.csv'))
    # The two `{}` slots are filled with the docking program and the data folder name.
    dataset_path_template = f'{data_root}{{}}/{{}}/{dataset}/'
    shared_options = dict(
        dataset_path_template=dataset_path_template,
        dock_software_list=dock_software_list,
        input_ligand_format_dict=input_ligand_format_dict,
        distance=distance,
        save_states=pose_num,
        generated_data_folder_name=generated_data_folder_name,
        reverse=reverse,
    )
    # 1) cut a pocket within `distance` Å (5 Å here) around every pose,
    # 2) pair each pose with its pocket and pickle the complex.
    generate_pocket(protein_path=protein_path, **shared_options)
    generate_complex(**shared_options)
    

## Save DTIGN graphs for each ligand with bond features

In [None]:
# conda activate base
import os, re
import pandas as pd
import numpy as np
import pickle
from scipy.spatial import distance_matrix
import multiprocessing
from itertools import repeat
import networkx as nx
import torch 
from torch.utils.data import Dataset, DataLoader
from rdkit import Chem
from rdkit import RDLogger
from rdkit import Chem
from torch_geometric.data import Batch, Data
from tqdm import tqdm
import warnings
RDLogger.DisableLog('rdApp.*')
np.set_printoptions(threshold=np.inf)
warnings.filterwarnings('ignore')
from torch_geometric.data import Batch

# %%
def one_of_k_encoding(k, possible_values):
    """Strict one-hot encoding of ``k`` over ``possible_values``.

    Returns a list of booleans that is True where the entry equals ``k``.
    Raises ValueError when ``k`` is not one of ``possible_values``.
    """
    if k in possible_values:
        return list(map(lambda value: k == value, possible_values))
    raise ValueError(f"{k} is not a valid value in {possible_values}")


def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encoding of ``x`` over ``allowable_set``; values outside the set map to its last slot."""
    value = x if x in allowable_set else allowable_set[-1]
    return [value == candidate for candidate in allowable_set]


def atom_features(mol, graph, atom_symbols=['C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Br', 'I'], explicit_H=True):
    """Add one node per atom of ``mol`` to ``graph`` carrying a float32 feature tensor ``feats``.

    The vector concatenates one-hot blocks (values outside a block fall into its last slot):
    element (``atom_symbols`` + 'Unknown'), degree 0-6, implicit valence 0-6, hybridization
    (SP, SP2, SP3, SP3D, SP3D2), an aromatic flag and — only when ``explicit_H`` is True — the
    total hydrogen count 0-4. With the defaults that is 10 + 7 + 7 + 5 + 1 + 5 = 35 values.
    Nodes are keyed by ``atom.GetIdx()``.
    """
    hybridizations = [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
    ]
    symbol_vocab = atom_symbols + ['Unknown']
    count_range = [0, 1, 2, 3, 4, 5, 6]
    for atom in mol.GetAtoms():
        encoded = []
        encoded += one_of_k_encoding_unk(atom.GetSymbol(), symbol_vocab)
        encoded += one_of_k_encoding_unk(atom.GetDegree(), count_range)
        encoded += one_of_k_encoding_unk(atom.GetImplicitValence(), count_range)
        encoded += one_of_k_encoding_unk(atom.GetHybridization(), hybridizations)
        encoded.append(atom.GetIsAromatic())
        # The hydrogen-count block is included when explicit_H is True (the default).
        if explicit_H:
            encoded += one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])
        feature_array = np.array(encoded).astype(np.float32)
        graph.add_node(atom.GetIdx(), feats=torch.from_numpy(feature_array))

def bond_features(bond, use_chirality=True):
    """Encode an RDKit bond as a numpy boolean vector.

    Layout: bond-type one-hot (single, double, triple, aromatic), conjugated flag, in-ring flag,
    then — when ``use_chirality`` — a stereo one-hot over STEREONONE / STEREOANY / STEREOZ /
    STEREOE, where any other stereo label falls into the last slot.
    """
    bond_type = bond.GetBondType()
    bond_kinds = (Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                  Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC)
    encoded = [bond_type == kind for kind in bond_kinds]
    encoded.append(bond.GetIsConjugated())
    encoded.append(bond.IsInRing())
    if use_chirality:
        stereo_labels = ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"]
        encoded.extend(one_of_k_encoding_unk(str(bond.GetStereo()), stereo_labels))
    return np.array(encoded)
        
def get_edge_index(mol, graph):
    """Add one undirected edge per bond of ``mol`` to ``graph``, storing ``bond_features`` as ``weight``."""
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        graph.add_edge(begin, end, weight=bond_features(bond))

def mol2graph(mol):
    """Convert an RDKit molecule into node features, bond features and a directed edge index.

    Returns:
        ``(x, x_bond, edge_index, False)`` where ``x`` stacks the per-atom ``feats`` tensors,
        ``x_bond`` is a float32 tensor with one row per directed edge (each bond appears in both
        directions) and ``edge_index`` is a ``(2, num_directed_edges)`` LongTensor.
        For a molecule without bonds (including one without atoms) returns ``([], [], [], True)``.
    """
    graph = nx.Graph()
    atom_features(mol, graph)
    get_edge_index(mol, graph)

    graph = graph.to_directed()
    # Check for bonds before stacking anything: torch.stack raises on an empty list, so an
    # atom-less molecule used to crash here instead of being reported as a failure.
    if not graph.edges(data=False):
        return [], [], [], True
    x = torch.stack([feats['feats'] for n, feats in graph.nodes(data=True)])
    # Stack into one ndarray first; building a tensor from a list of ndarrays is very slow.
    bond_matrix = np.stack([graph[u][v]['weight'] for u, v in graph.edges()]).astype(np.float32)
    x_bond = torch.from_numpy(bond_matrix)
    edge_index = torch.stack([torch.LongTensor((u, v)) for u, v in graph.edges(data=False)]).T

    return x, x_bond, edge_index, False

def inter_graph(ligand, pocket, dis_threshold = 5.):
    """Directed ligand<->pocket contact edges for atom pairs closer than ``dis_threshold``.

    Pocket atom indices are shifted by the ligand atom count so both molecules share one index
    space. Returns a ``(2, num_edges)`` LongTensor holding every contact in both directions.
    If no pair is within the threshold, ``torch.stack`` raises on the empty edge list; the caller
    (``mols2graphs``) catches that and records the complex as failed.
    """
    n_ligand_atoms = ligand.GetNumAtoms()
    ligand_xyz = ligand.GetConformers()[0].GetPositions()
    pocket_xyz = pocket.GetConformers()[0].GetPositions()
    close_ligand, close_pocket = np.where(distance_matrix(ligand_xyz, pocket_xyz) < dis_threshold)

    contacts = nx.Graph()
    for lig_idx, poc_idx in zip(close_ligand, close_pocket):
        contacts.add_edge(lig_idx, poc_idx + n_ligand_atoms)
    contacts = contacts.to_directed()

    directed_pairs = [torch.LongTensor((src, dst)) for src, dst in contacts.edges(data=False)]
    return torch.stack(directed_pairs).T

# %%
# Docking-program tags that may appear in complex file names. 'qvina_w' itself contains an
# underscore, so the tag cannot simply be taken as the second '_'-separated token.
_KNOWN_DOCK_SOFTWARE = ('karmadock', 'diffdock', 'tankbind', 'gnina', 'qvina_w', 'vina')


def _parse_complex_name(complex_path):
    """Split a `<cid>_<software>_<tag>_complex_...` file name into (software, tag); tag is '' if absent."""
    file_name = complex_path.split('/')[-1]
    dock_software = file_name.split('_')[1]  # fallback for tags not in the known list
    after_cid = file_name.split('_', 1)[-1]
    # Longest names first, in case one known tag is a prefix of another.
    for name in sorted(_KNOWN_DOCK_SOFTWARE, key=len, reverse=True):
        if after_cid.startswith(name + '_'):
            dock_software = name
            break
    # Search the file name only, so directory names such as `qvina_w/pocket_complex` cannot match.
    match = re.search(f"{re.escape(dock_software)}_(.*?)_complex", file_name)
    pocket_or_pose = match.group(1) if match else ''
    return dock_software, pocket_or_pose


def mols2graphs(complex_path_list, label, save_path, dis_threshold):
    """Convert pickled (ligand, pocket) complexes into one PyG ``Batch`` saved at ``save_path``.

    Each file in ``complex_path_list`` must hold a pickled ``(ligand, pocket)`` pair of RDKit
    molecules, as written by ``generate_complex``. Every complex becomes a ``Data`` graph with
    ligand atoms first, intra-molecular bond edges, ligand-pocket contact edges closer than
    ``dis_threshold`` and ``y = label``; the graphs are merged with ``Batch.from_data_list``.
    The batch is written unless the notebook-global ``skip_exist`` is truthy and ``save_path``
    already exists.

    Returns:
        ``complex_path_list`` itself if no complex could be converted (nothing is saved);
        otherwise the list of paths that failed (missing, unreadable, bond-less, or without
        ligand-pocket contacts).
    """
    data_list = []
    fail_path = []
    for complex_path in complex_path_list:
        if not os.path.exists(complex_path):
            print('Complex file not found:', complex_path)
            fail_path.append(complex_path)
            continue
        try:
            # Only load files produced by this pipeline: unpickling can execute arbitrary code.
            with open(complex_path, 'rb') as f:
                ligand, pocket = pickle.load(f)
        except (EOFError, pickle.UnpicklingError) as err:
            # Empty or truncated pickle (e.g. an interrupted write). Record it as a failure so it is
            # reported and a list of only such files never reaches Batch.from_data_list([]).
            print(f'Error: could not unpickle ({err}). Check the file contents: {complex_path}')
            fail_path.append(complex_path)
            continue

        atom_num_l = ligand.GetNumAtoms()
        atom_num_p = pocket.GetNumAtoms()

        pos_l = torch.FloatTensor(ligand.GetConformers()[0].GetPositions())
        pos_p = torch.FloatTensor(pocket.GetConformers()[0].GetPositions())
        x_l, x_l_bond, edge_index_l, fail_l = mol2graph(ligand)
        x_p, x_p_bond, edge_index_p, fail_p = mol2graph(pocket)
        if fail_l or fail_p:
            print('Failed to read complex file:', complex_path)
            fail_path.append(complex_path)
            continue

        # Ligand atoms come first; pocket atom indices are shifted by the ligand atom count.
        x = torch.cat([x_l, x_p], dim=0)
        x_bond = torch.cat([x_l_bond, x_p_bond], dim=0)
        edge_index_intra = torch.cat([edge_index_l, edge_index_p + atom_num_l], dim=-1)
        try:
            edge_index_inter = inter_graph(ligand, pocket, dis_threshold=dis_threshold)
        except Exception:
            # e.g. no ligand-pocket atom pair within dis_threshold (empty list passed to torch.stack)
            print('Failed to read complex edges:', complex_path)
            fail_path.append(complex_path)
            continue

        y = torch.FloatTensor([label])
        pos = torch.concat([pos_l, pos_p], dim=0)
        # 0 marks ligand atoms, 1 marks pocket atoms.
        split = torch.cat([torch.zeros((atom_num_l, )), torch.ones((atom_num_p,))], dim=0)
        dock_software, pocket_or_pose = _parse_complex_name(complex_path)

        data = Data(x=x, x_bond=x_bond, edge_index_intra=edge_index_intra, edge_index_inter=edge_index_inter, y=y, pos=pos, dock_software=dock_software, pocket_or_pose=pocket_or_pose, split=split)
        data_list.append(data)

    if not data_list:
        return complex_path_list
    merged_data = Batch.from_data_list(data_list)
    if not (skip_exist and os.path.exists(save_path)):
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(merged_data, save_path)
    return fail_path

# %%
class PLIDataLoader(DataLoader):
    """DataLoader that batches samples with the dataset's own ``collate_fn``.

    ``data`` must expose a ``collate_fn`` attribute; every other keyword argument
    (batch_size, shuffle, num_workers, ...) is forwarded to ``torch.utils.data.DataLoader``.
    """

    def __init__(self, data, **kwargs):
        batch_collator = data.collate_fn
        super(PLIDataLoader, self).__init__(data, collate_fn=batch_collator, **kwargs)

class GraphDataset(Dataset):
    """
    This class is used for generating graph objects using multi process
    """
    def __init__(self, dataset_path_template, data_df, dock_software_list, save_dir, run_datafold=None, dis_threshold=5, num_pose=1, graph_type='Graph_GIGN', assay_type='pIC50', num_process=8, create=False, addition=False):
        self.dataset_path_template = dataset_path_template
        self.dock_software_list = dock_software_list
        self.save_dir = save_dir
        self.data_df = data_df
        self.dis_threshold = dis_threshold
        self.num_pose = num_pose
        self.graph_type = graph_type
        self.create = create
        self.graph_paths = None
        self.compound_ids = None
        self.assay_type = assay_type
        self.num_process = num_process
        self.mean, self.std = 0, 1
        self.run_datafold = run_datafold
        self.addition = addition
        self._pre_process()

    def _pre_process(self):
        dataset_path_template = self.dataset_path_template
        dock_software_list = self.dock_software_list
        data_df = self.data_df
        graph_type = self.graph_type
        save_dir = self.save_dir

        complex_path_list, compound_id_list, pIC50_list, score_list, graph_path_list, dis_threshold_list = [], [], [], [], [], []
        not_found_list, not_found_flag = [], True
        for i, row in data_df.iterrows():
            cid, pIC50 = 'canSAR' + str(row['COMPOUND']), float(row['p' + self.assay_type]) # e.g. pIC50
            complex_path_list_cid, compound_id_list_cid, score_list_cid = [], [], []
            for dock_software in dock_software_list:
                data_dir = dataset_path_template.format(dock_software, 'pocket_complex')
                file_list = os.listdir(data_dir)
                if dock_software in ['karmadock', 'tankbind', 'vina']: # cid files in pocket folders
                    for pocket in file_list:
                        cid_folder = os.path.join(data_dir, pocket, cid)
                        if os.path.exists(cid_folder):
                            cid_folder_files = os.listdir(cid_folder)
                            rdkit_file_list = [f'{cid}_vina_{pocket}_pose{pose}_complex_{self.dis_threshold}A.rdkit' for pose in range(self.num_pose)] if dock_software == 'vina' else [file for file in cid_folder_files if file.split('.')[-1]=='rdkit']
                            if rdkit_file_list:
                                for rdkit_file in rdkit_file_list:
                                    complex_path = os.path.join(data_dir, pocket, cid, rdkit_file)
                                    if os.path.exists(complex_path):
                                        not_found_flag = False
                                        complex_path_list_cid.append(complex_path)
                else:
                    cid_folder = os.path.join(data_dir, cid)
                    if os.path.exists(cid_folder):
                        cid_folder_files = os.listdir(cid_folder)
                        rdkit_file_list = [file for file in cid_folder_files if file.split('.')[-1]=='rdkit']
                        if rdkit_file_list:
                            not_found_flag = False
                            for rdkit_file in rdkit_file_list:
                                complex_path = os.path.join(data_dir, cid, rdkit_file)
                                complex_path_list_cid.append(complex_path)
            
            graph_path = os.path.join(save_dir, f"{cid}_{graph_type}_{self.dis_threshold}A.pyg")
            if not_found_flag:
                not_found_list.append(graph_path)
            else:
                graph_path_list.append(graph_path)
                complex_path_list.append(complex_path_list_cid)
                compound_id_list.append(cid)
                pIC50_list.append(pIC50)
                dis_threshold_list.append(self.dis_threshold)
                
        self.mean, self.std = np.mean(pIC50_list), np.std(pIC50_list)
        if self.create:
            print('Generate complex graph...')
            # multi-thread processing
            pool = multiprocessing.Pool(self.num_process)
            for complex_path, pIC50 ,graph_path, dis_threshold in tqdm(zip(complex_path_list, pIC50_list, graph_path_list, dis_threshold_list)):
                not_found_path = mols2graphs(complex_path, pIC50 ,graph_path, dis_threshold)
                if len(not_found_path) == len(complex_path):
                    not_found_list.append(graph_path)
            not_found_save_path = save_dir + f'/{self.graph_type}_not_found_list.pkl'
            os.makedirs(os.path.dirname(not_found_save_path), exist_ok=True)
            if self.addition:
                print(f'Adding not found graph paths into {not_found_save_path}...')
                with open(not_found_save_path, 'rb') as f:
                    existing_not_found_list = pickle.load(f)
                not_found_list = existing_not_found_list + not_found_list
            with open(not_found_save_path, 'wb') as f:
                pickle.dump(not_found_list, f)
            pool.close()
            pool.join()
        
        with open(save_dir + f'/{self.graph_type}_not_found_list.pkl', 'rb') as f:
            not_found_list = pickle.load(f)
        not_found_list = [path.replace(self.run_datafold, '.') for path in not_found_list]
        graph_path_list = [path.replace(self.run_datafold, '.') for path in graph_path_list]
        self.compound_ids = compound_id_list
        self.graph_paths = [x for x in graph_path_list if x not in not_found_list] # excluding unsuccessful creation
        
    def __getitem__(self, idx):
        data = torch.load(self.graph_paths[idx])
        match = re.search(r'canSAR(\d+)_', self.graph_paths[idx])
        cid = match.group(0)
        data['idx'] = cid[:-1]
        return data

    def collate_fn(self, batch):
        return Batch.from_data_list(batch)

    def __len__(self):
        return len(self.graph_paths)

    
# Build the complex-graph dataset of every split, then smoke-test that each one loads in batches.
# NOTE(review): skip_exist is read as a global by the graph-saving code above — TODO confirm.
skip_exist, graph_type = True, 'Graph_DTIGN'
protein_name, assay_type, pdb_name, pose_num, dis_threshold = 'PARP1', 'IC50', '6nrh', 3, 5
dock_software_options = ['karmadock', 'diffdock', 'tankbind', 'gnina', 'qvina_w', 'vina']
# qvina_w has to be generated subset by subset to avoid corrupted outputs.
dock_software_list = [dock_software_options[5]]  # -> ['vina']
data_root = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/{pdb_name}/'
activity_root = f'/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/data/{protein_name}/activity'
run_datafold = '/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN'
set_list = ['test', 'train_1', 'train_2', 'train_3', 'train_4', 'train_5']
dock_software_string = '_'.join(dock_software_list)
generated_data_folder_name = f'{graph_type}_{dock_software_string}_pose{pose_num}_{dis_threshold}A'
for set_name in set_list:
    data_df = pd.read_csv(os.path.join(activity_root, assay_type, f'{set_name}.csv'))
    dataset_path_template = data_root + '{}/{}/' + set_name + '/'
    save_dir = data_root + f'{generated_data_folder_name}/' + set_name + '/'
    # num_pose must match the pose count encoded in the output folder name.
    dataset = GraphDataset(dataset_path_template, data_df, dock_software_list, save_dir, run_datafold=run_datafold, num_pose=pose_num, graph_type=graph_type, assay_type=assay_type, dis_threshold=dis_threshold, create=True)
    print('Dataset size:', len(dataset))
    data_loader = PLIDataLoader(dataset, batch_size=256, shuffle=True, num_workers=4)
    for batch in data_loader:
        # Each batch carries x, x_bond, edge_index_intra, edge_index_inter, y, pos, split,
        # dock_software, pocket_or_pose and idx (one pocket_or_pose entry per graph).
        print(f'Loading {len(batch.pocket_or_pose)} data successfully')

## Screen confident (low-uncertainty) predictions on the collected ChEMBL database

In [None]:
import os
import pandas as pd
from tqdm import tqdm

# Inclusive range of the combined (normalised) uncertainty that counts as a confident prediction.
uncertainty_lower_bound, uncertainty_upper_bound, dropout_times, seed = 0, 0.047, 10, 294
# Directory with the per-target inference CSVs to screen
directory = f"/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/baseline/PARP1/infer_AFP/dropout_times={dropout_times}-seed={seed}/"
source_dir = '/home1/yueming/Drug_Discovery/Datasets/'
dropout_list = [0.1, 0.2, 0.3, 0.4, 0.5]  # rates 0.6-0.9 are currently excluded
dropout_column_list = [f'uncertainty-dropout={dropout}' for dropout in dropout_list]
afse_column_list = ['AFSE_uncertainty']
uncertainty_column_list = afse_column_list + dropout_column_list
# Flattened (min, max) pairs, one pair per entry of uncertainty_column_list, in the same order.
min_max_values = [0.20327282, 0.4505804, 0.003099999, 0.23411462, 0.00621228, 0.49756283, 0.010164783, 1.2705135, 0.00972938, 2.0085182, 0.027138822, 1.8844687]
assert len(min_max_values) == 2 * len(uncertainty_column_list), 'need one (min, max) pair per uncertainty column'

def min_max_scaling(column, min_value, max_value):
    """Linearly rescale ``column`` so that ``min_value`` maps to 0 and ``max_value`` to 1."""
    return (column - min_value) / (max_value - min_value)

def sum_up_uncertainty(data, uncertainty_column_list, min_max_values, file_path):
    """
    Add a ``sum_up_uncertainty`` column: the mean of the min-max-scaled uncertainty columns.

    Args:
        data: DataFrame to modify in place.
        uncertainty_column_list: uncertainty column names to combine.
        min_max_values: flattened (min, max) pairs; column ``i`` of the list uses
            ``min_max_values[2*i]`` and ``min_max_values[2*i+1]``.
        file_path: source of ``data``, used only in messages.

    Columns missing from ``data`` are reported and skipped. If none is present,
    the new column is filled with NaN, so no row passes a range filter on it.
    """
    scaled_columns = []
    for i, uncertainty_column in enumerate(uncertainty_column_list):
        if uncertainty_column not in data.columns:
            # The column is missing from the DataFrame
            print(f"Column {uncertainty_column} is missing in {file_path}")
            continue
        # Indexing by list position keeps pairs aligned even when earlier columns are skipped.
        min_value, max_value = min_max_values[2 * i], min_max_values[2 * i + 1]
        scaled_columns.append(min_max_scaling(data[uncertainty_column].values, min_value, max_value))
    if not scaled_columns:
        print(f"No uncertainty column found in {file_path}; sum_up_uncertainty set to NaN")
        data['sum_up_uncertainty'] = float('nan')
        return
    data['sum_up_uncertainty'] = sum(scaled_columns) / len(scaled_columns)

def search_and_add_assays(selected_rows, target, assay):
    """
    Left-join the experimental records of ``target``/``assay`` onto ``selected_rows`` by SMILES.

    The records are read from ``{source_dir}/{target}/{assay}/{target}_{assay}_all.csv``.
    Their ``SMILES`` column is renamed to ``smiles`` so it matches the prediction table.
    Every row of ``selected_rows`` is kept (left join).
    """
    csv_location = f'{source_dir}/{target}/{assay}/{target}_{assay}_all.csv'
    experimental = pd.read_csv(csv_location).rename(columns={'SMILES': 'smiles'})
    return pd.merge(selected_rows, experimental, how='left', on='smiles')
    
# Collect the low-uncertainty predictions of every inference CSV, attach the matching
# experimental assay records, and save everything as one table.
selected_frames = []
# Sorted so that the output row order is reproducible across runs.
for filename in tqdm(sorted(os.listdir(directory))):
    if not filename.endswith(".csv"):
        continue
    # Drop the last 8 characters (a fixed suffix ending in '.csv' — TODO confirm its exact form);
    # the remainder is '{target}-{assay}'.
    last_names = filename[:-8]
    target = last_names.split('-')[0]
    assay = last_names[len(target) + 1:]
    file_path = os.path.join(directory, filename)
    df = pd.read_csv(file_path)
    sum_up_uncertainty(df, uncertainty_column_list, min_max_values, file_path)
    # Keep rows whose combined uncertainty lies within the inclusive bounds (NaN never matches).
    selected_rows = df[df["sum_up_uncertainty"].between(uncertainty_lower_bound, uncertainty_upper_bound)]
    if len(selected_rows) > 0:
        selected_frames.append(search_and_add_assays(selected_rows, target, assay))

# Concatenate once at the end instead of growing the frame inside the loop.
result_df = pd.concat(selected_frames, ignore_index=True) if selected_frames else pd.DataFrame()

# Save the result as a CSV file
result_file_path = f"/home1/yueming/Drug_Discovery/Baselines/GIGN-main/GIGN/baseline/PARP1/infer_AFP/dropout_times={dropout_times}-seed={seed}-uncertainty_range={uncertainty_lower_bound}~{uncertainty_upper_bound}.csv"
result_df.to_csv(result_file_path, index=False)
print(f'Screen out {len(result_df)} samples in total.')

## Convert an isomeric SMILES to a canonical SMILES without stereochemistry

In [None]:
from rdkit import Chem

# Isomeric SMILES to convert (contains stereocentres such as [C@@H]).
isomeric_smiles = "c1cc2c(c(c1)NC(=O)CN3CCN(CC3)C(=O)[C@@H]4[C@H]([C@H]([C@@H](O4)n5cnc6c5ncnc6N)O)O)CNC2=O"

# Parse into an RDKit molecule; MolFromSmiles returns None when parsing fails.
mol = Chem.MolFromSmiles(isomeric_smiles)

if mol is None:
    print("Invalid SMILES input")
else:
    # MolToSmiles writes canonical SMILES by default; isomericSmiles=False additionally
    # drops the stereochemistry annotations.
    canonical_smiles = Chem.MolToSmiles(mol, isomericSmiles=False)
    print("Canonical SMILES:", canonical_smiles)
