In [None]:
#@title Install Dependencies
%%capture
# Third-party packages used by the cells below (gradio, rdkit, fair-esm, ...).
# Every package is version-pinned except prody.
!pip install -q e3nn==0.5.1 fair-esm==2.0.0 networkx==3.2.1 pybind11==2.12.0 rdkit==2023.9.6 requests==2.32.3 scikit-learn==1.5.0 torch-geometric==2.2.0 prody pydantic==2.10.6 gradio==4.44.1
# PyG extension packages, resolved against the wheel index for torch 2.4.0 + cu121.
# NOTE(review): this presumes the runtime's installed torch/CUDA match that index — confirm.
!pip install -q pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+cu121.html

In [None]:
#@title Clone DiffDock
%%capture
# Clone the DiffDock repository and make the checkout the working directory;
# the interface cell below launches `inference.py` by relative path.
# NOTE(review): re-running this cell clones/changes directory again from inside the checkout.
!git clone https://github.com/gcorso/DiffDock.git
%cd DiffDock
# Download a file from Google Drive and unpack `diffdock_cache.zip`
# (presumably the file just downloaded — verify the archive name).
!gdown --fuzzy "https://drive.google.com/file/d/1CFuI1XEJpEm-jTGRRmD7FjGhJhMIE3vJ/view?usp=sharing"
!unzip diffdock_cache.zip

# Instantiate the ESM-2 (t33, 650M) pretrained model once and discard the result;
# presumably done to fetch the weights ahead of the first docking run — TODO confirm.
import esm
_ = esm.pretrained.esm2_t33_650M_UR50D()

In [None]:
#@title Gradio Interface

import os
import re
import shutil
import subprocess
import pickle
import zipfile
import gradio as gr
import pandas as pd

from rdkit import Chem

# Matches from an occurrence of "ATOM" or "TER" up to the end of that line.
# The pattern is not anchored, so the keyword may also be found in the middle of a line.
atom_term_compiled = re.compile(r'(ATOM|TER).*')
# Captures three upper-case letters, one alphanumeric character and a signed integer;
# PDBEditor reads these groups as residue name, chain id and residue number.
aa_chain_pos_compiled = re.compile(r'([A-Z]{3}).([a-zA-Z0-9])\s*(-?\d+)')
# Captures the signed decimal number between "_confidence" and ".sdf" in a file name.
confidence_compiled = re.compile(r'_confidence(-?\d+\.\d+)\.sdf')
# Lazily captures everything between a "MODEL <number>" header and the next "ENDMDL".
# Not referenced by the code visible in this notebook.
retreive_mdl_compiled = re.compile(r'MODEL\s+[0-9]+\s+((\n|.)*?)ENDMDL')
# Matches ATOM/TER lines using a negative lookahead on the listed DT/DA/DC/DG/DI/A/U/C/G/I
# codes (presumably to drop nucleic-acid residues — TODO confirm). Not referenced by the
# code visible in this notebook.
retrieve_line_without_na_compiled = re.compile(r'\n(ATOM|TER.).{13}(?!\s+(DT|DA|DC|DG|DI|A|U|C|G|I)).*')

# Atom-type labels (read from column 78 onward of a PDBQT ATOM line) mapped to the
# plain element symbol written to the PDB output. Labels not listed pass through unchanged.
atom_type_map = {'HD': 'H', 'HS': 'H',
                 'NA': 'N', 'NS': 'N',
                 'A' : 'C', 'G' : 'C', 'CG0': 'C', 'CG1': 'C', 'CG2': 'C', 'CG3': 'C', 'G0': 'C', 'G1': 'C', 'G2': 'C', 'G3': 'C',
                 'OA': 'O', 'OS': 'O',
                 'SA': 'S'}

def pdbqt_to_pdb(pdbqt_str: str):
    """Convert PDBQT text into PDB-style records.

    Only lines starting with ``ATOM`` are converted. Protonation-state residue
    names (HID/HIP/HIE, GLH, ASH, LYN, CYM/CYX) are first replaced by their
    standard names, the partial-charge column is dropped, and the atom type is
    translated through ``atom_type_map``. A ``TER`` record is inserted whenever
    the chain id changes and once more after the last atom.

    Args:
        pdbqt_str: Full PDBQT text.

    Returns:
        ``(records, pdb_text)``. ``records`` is a list of dicts: atom records
        carry the keys ``atom_idx, atom_name, alt_id, res_name, chain, res_pos,
        others, atom_type``; TER records carry ``atom_idx, res_name, chain,
        res_pos``. ``pdb_text`` is the formatted text, one line per record.
        Input without any ``ATOM`` line yields ``([], '')``.

    Raises:
        IndexError, ValueError: if an ``ATOM`` line is too short or has a
            non-integer residue number in columns 23-26.
    """
    sub_re_map = {'HIS': re.compile(r'HID|HIP|HIE'),
                  'GLU': re.compile(r'GLH'),
                  'ASP': re.compile(r'ASH'),
                  'LYS': re.compile(r'LYN'),
                  'CYS': re.compile(r'CYM|CYX'),}
    # NOTE: the substitution runs over the whole text, not just the residue-name column.
    for aa, re_comp in sub_re_map.items():
        pdbqt_str = re_comp.sub(aa, pdbqt_str)

    def make_ter(after_atom: dict):
        # A TER record repeats the residue/chain of the atom it follows and takes the next serial.
        return {'atom_idx': after_atom['atom_idx'] + 1,
                'res_name': after_atom['res_name'],
                'chain'   : after_atom['chain'],
                'res_pos' : after_atom['res_pos'],
                }

    protein_data = []
    ter_count = 0      # number of TER records inserted so far; shifts later serials
    last_atom = None   # most recent atom record, source of the next TER
    for idx, line in enumerate(pdbqt_str.strip().splitlines()):
        if not line.startswith('ATOM'):
            continue
        atom_type = line[77:].strip()
        atom_type = atom_type_map.get(atom_type, atom_type)
        chain = line[21]
        if last_atom is not None and chain != last_atom['chain']:
            protein_data.append(make_ter(last_atom))
            ter_count += 1
        # The serial is derived from the line's position in the input (0-based,
        # counting every line) plus the TER records inserted so far.
        last_atom = {'atom_idx'   : idx + ter_count           ,
                     'atom_name'  : line[12:16].strip()       ,
                     'alt_id'     : line[16]                  ,
                     'res_name'   : line[17:20].strip()       ,
                     'chain'      : chain                     ,
                     'res_pos'    : int(line[22:26].strip())  ,
                     'others'     : line[26:66]               ,  # skip partial charge
                     'atom_type'  : atom_type                 ,
                     }
        protein_data.append(last_atom)

    # Without a single ATOM record there is nothing to terminate; previously this
    # path crashed with a TypeError (None + 1).
    if last_atom is None:
        return [], ''
    protein_data.append(make_ter(last_atom))

    # NOTE(review): atom names are left-justified in their 4-character field and the
    # element is written after a fixed run of blanks; check both against the PDB
    # column layout expected by downstream parsers.
    format_str = "ATOM  {:5d} {:4s}{:1s}{:3s} {:1s}{:4d}{:40s}            {}\n"
    term_str   = "TER   {:5d}      {:3s} {:1s}{:4d}\n"
    lines = []
    for entry in protein_data:
        if 'atom_name' in entry:   # atom record
            lines.append(format_str.format(
                entry['atom_idx'], entry['atom_name'], entry['alt_id'],
                entry['res_name'], entry['chain'], entry['res_pos'],
                entry['others'], entry['atom_type'],
            ))
        else:                      # TER record
            lines.append(term_str.format(
                entry['atom_idx'], entry['res_name'],
                entry['chain'], entry['res_pos'],
            ))

    return protein_data, ''.join(lines)

class PDBEditor:
    def __init__(self, pdbqt_str: str | None=None):
        self.pdbqt_str = pdbqt_str
        self.pdb_chain_dict = {}
    
    def parse_pdb_text_to_dict(self, display_flex_dict = None):
        self.pdb_chain_dict = {}
        chain_aa_dict = {}
        for atom_term_line in re.finditer(atom_term_compiled, self.pdbqt_str):
            line = atom_term_line.group(0)
            aa, chain, aa_pos = re.search(aa_chain_pos_compiled, line).group(1,2,3)
            aa_pos = int(aa_pos)
            if chain not in self.pdb_chain_dict:
                self.pdb_chain_dict[chain] = {}
                chain_aa_dict[chain] = []
            aa_pos_dict = self.pdb_chain_dict[chain]
            if aa_pos not in aa_pos_dict:
                aa_pos_dict[aa_pos] = []
                chain_aa_dict[chain].append(aa)
            aa_pos_dict[aa_pos].append(line)
        for chain, aa_pos_dict in self.pdb_chain_dict.items():
            aa_cnt = len(aa_pos_dict)
            for pos, text_list in aa_pos_dict.items():
                aa_pos_dict[pos] = '\n'.join(text_list)
            self.pdb_chain_dict[chain] = pd.DataFrame.from_dict(aa_pos_dict, 'index')
            if display_flex_dict is None:
                self.pdb_chain_dict[chain]['Display'] = [True] * aa_cnt
                self.pdb_chain_dict[chain]['Flexible'] = [False] * aa_cnt
            else:
                self.pdb_chain_dict[chain]['Display'] = display_flex_dict[chain]['Display']
                self.pdb_chain_dict[chain]['Flexible'] = display_flex_dict[chain]['Flexible']
            self.pdb_chain_dict[chain]['AA_Name'] = chain_aa_dict[chain]
        
    def parse_logic(self, series: pd.Series, logic: str):
        def replace_expression(match):
            expr = match.group()
            negation = ''
            if expr.startswith('~'):
                negation = '~'
                expr = expr[1:]
            if '-' in expr:  # range
                start, end = map(int, expr.rsplit('-', 1))
                return f"{negation}((series.index >= {start}) & (series.index <= {end}))"
            else:  # single value
                return f"{negation}(series.index == {expr})"
        
        logic = logic.replace(' ', '').replace(',', '|')    # "," is the same as or "|"
        logic_eval = re.sub(r'~?-?\d+-\d+|~?\d+(?:,~?\d+)*', replace_expression, logic)
        
        try:
            result = pd.eval(logic_eval, local_dict={'series': series}, engine='python') # need to pass local_dict or else Nuitka compiled code won't work
            series[result] = True
        except:
            return None
        
        return series
    
    def update_display(self, chain: str, display_str: str | None):
        full_df = self.pdb_chain_dict[chain]
        if display_str is None:
            display_series = pd.Series([False] * len(full_df), list(full_df.index))
            full_df['Display'] = display_series
            return
        elif not display_str:
            display_series = pd.Series([True] * len(full_df), list(full_df.index))
            full_df['Display'] = display_series
            return
        display_series = pd.Series([False] * len(full_df), list(full_df.index)) # default to False
        result = self.parse_logic(display_series, display_str)
        if result is None:
            return f'Invalid syntax for chain {chain}.'
        full_df['Display'] = display_series
    
    def _condense_to_range(self, list_of_nums: list[int]):
        final_text = ''
        start = list_of_nums[0]
        end = list_of_nums[0]
        range_cnt = 1
        for num in list_of_nums[1:]:
            if num == start + range_cnt:
                end = num
                range_cnt += 1
            else:
                if start == end:
                    final_text += f'{start}|'
                else:
                    final_text += f'{start}-{end},'
                start = num
                end = num
                range_cnt = 1
        if start == end:
            final_text += f'{start}'
        else:
            final_text += f'{start}-{end}'
        return final_text
    
    def convert_to_range_text(self, chain: str):
        s = self.pdb_chain_dict[chain]['Display']
        mask = s == True
        displayed_pos = s[mask].index.to_list()
        if not displayed_pos:
            return f'~1-{max(s.index.to_list())}'
        if len(displayed_pos) == len(s):
            return ''   # empty text if everything is displayed
        return self._condense_to_range(displayed_pos)
    
    def convert_full_dict_to_text(self):
        protein_strs = []
        for df in self.pdb_chain_dict.values():
            string = '\n'.join(df[0].to_list())
            protein_strs.append(string)
        return '\n'.join(protein_strs)
    
    def convert_dict_to_pdb_text(self, return_scheme=False):
        pdbqt_str = ''
        cnt = 0
        for df in self.pdb_chain_dict.values():
            mask = df['Display'] == True
            string = '\n'.join(df[0][mask].to_list())
            if string:
                cnt += 1
            pdbqt_str += string
            if string:
                pdbqt_str += '\n'
        if return_scheme:
            if cnt > 1:
                scheme = 'chainindex'
            else:
                scheme = 'residueindex'
            return pdbqt_str, scheme
        return pdbqt_str
    
    def check_format_type(self):
        for chain_dict in self.pdb_chain_dict.values():
            line = chain_dict.iloc[0, 0]
            if line[70:76].strip():
                return 'pdbqt'
            else:
                return 'pdb'

def convert_row_to_string_float(csv_row_str: str):
    """Split one ``name,score`` line of the confidence log into ``(name, float(score))``.

    The log is written without quoting, so the score is always the text after
    the last comma; splitting from the right keeps names that themselves
    contain commas intact.

    Raises:
        ValueError: if the line has no comma or the score is not a number.
    """
    name, score = csv_row_str.strip().rsplit(',', 1)
    return name, float(score)

class DiffDockInterface:
    """Gradio front end that prepares inputs for, launches and monitors DiffDock.

    All state lives on the instance and is created by ``setup_interface``.
    The attribute name ``ligand_laoded`` (sic) is kept as spelled in case
    other code refers to it.
    """

    def setup_interface(self):
        """Create the working directories/log files, build the UI and launch it."""
        self.pdb_editor = None
        self.process = None
        self.curr_docking = False
        self.total_ligands = 0
        self.curr_dir = os.path.abspath('')
        self.dock_input_dir = os.path.join(self.curr_dir, 'dock_input')
        self.dock_output_dir = os.path.join(self.curr_dir, 'dock_output')
        self.zipped_dir = os.path.join(self.curr_dir, 'zipped_files')
        self.cache_dir = os.path.join(self.dock_output_dir, 'cache_files')
        os.makedirs(self.dock_input_dir, exist_ok=True)
        os.makedirs(self.dock_output_dir, exist_ok=True)
        os.makedirs(self.zipped_dir, exist_ok=True)
        os.makedirs(self.cache_dir, exist_ok=True)
        self.log_file = os.path.join(self.cache_dir, 'docking_log.log')
        with open(self.log_file, 'w') as f:
            pass  # create/truncate so the progress poller always has a file to read
        self.csv_file = os.path.join(self.cache_dir, 'confidence_log.csv')
        with open(self.csv_file, 'w') as f:
            f.write('Name,Confidence\n')
        self.recorded_names = []
        with gr.Blocks(css='footer{display:none !important}') as Interface:
            gr.Markdown('<span style="font-size:25px; font-weight:bold; ">DiffDock Interface</span>')
            with gr.Row():
                with gr.Column():
                    gr.Markdown('<span style="font-size:20px; font-weight:bold; ">Protein PDB/CIF</span>')
                    protein_stat = gr.Textbox(label='Protein Status :')
                    protein_input = gr.File(file_count='single',
                                            file_types=['.pdb', '.mds'],
                                            label='Upload Protein PDB/CIF or MDS file')
                with gr.Column():
                    gr.Markdown('<span style="font-size:20px; font-weight:bold; ">DiffDock CSV</span>')
                    ligand_stat = gr.Textbox(label='Ligand Status :')
                    ligand_input = gr.File(file_count='single',
                                            file_types=['.csv'],
                                            label='Upload CSV File (prepared with Molecule Converter)')
            with gr.Row():
                display_text = gr.Textbox(label='Display',
                                          placeholder='A:1-15,20-35 B:15-20',
                                          interactive=True)
            with gr.Row():
                dock_progress = gr.Text(label='Progress',
                                        interactive=False,
                                        scale=4,
                                        max_lines=15,)
                with gr.Column(scale=1):
                    dock_button = gr.Button('Dock', interactive=False)
                    zip_checkbox = gr.Checkbox(label='Zip Result')
                    zip_progress = gr.Textbox(label='Progress', interactive=False)
                    zip_name = gr.Textbox(label='Zip File Name',
                                          placeholder='docked_result',
                                          interactive=True)
                    zipped_file_output = gr.File(label='Zipped file',
                                                 file_count='single',
                                                 file_types=['.zip'])
            dock_status = gr.Textbox(label='Docking Status :')
            energy_df = gr.DataFrame(headers=['Name', 'Confidence'],
                                     type='array',
                                     interactive=False)
            self.protein_loaded = False
            self.ligand_laoded = False
            protein_input.change(self.upload_protein_file, inputs=protein_input, outputs=[protein_stat, dock_button, display_text])
            ligand_input.change(self.upload_ligand_file, inputs=ligand_input, outputs=[ligand_stat, dock_button])
            dock_button.click(self.start_docking, inputs=display_text, outputs=None)

            # Polling callbacks: the docking subprocess is observed only through
            # the files it writes, so every panel refreshes on a timer.
            Interface.load(self.read_docked_progress, None, dock_progress, every=1)
            Interface.load(self.check_curr_docking_status, None, [dock_status, dock_button], every=3)
            Interface.load(self.new_dataframe, None, energy_df, every=1)
            Interface.load(self.check_confidence, None, None, every=1)
            Interface.load(self.zip_docked_files, [zip_checkbox, zip_name],
                           [zipped_file_output, zip_checkbox, zip_progress], every=1)
            Interface.queue().launch(share=True)

    def _can_dock(self):
        """True when both a protein and a ligand CSV are loaded."""
        return self.protein_loaded and self.ligand_laoded

    def upload_protein_file(self, file):
        """Handle a change of the protein upload.

        Args:
            file: Path of the uploaded ``.pdb``/``.mds`` file, or a falsy value
                when the upload was cleared.

        Returns:
            ``(status text, dock-button update, display-textbox update)``.
        """
        display_update = gr.update()
        if not file:
            self.pdb_editor = None
            self.protein_loaded = False
            return 'Protein file removed', gr.update(interactive=False), display_update

        if file.endswith('.pdb'):
            with open(file) as f:
                pdb_str = f.read()
            self.pdb_editor = PDBEditor(pdb_str)
            self.pdb_editor.parse_pdb_text_to_dict()
            self.protein_loaded = True
            return 'Protein file loaded (pdb)', gr.update(interactive=self._can_dock()), display_update

        # SECURITY NOTE: unpickling can execute arbitrary code and the file is
        # user-uploaded (the app is launched with share=True). Only load .mds
        # files from trusted sources, or move to a safe serialisation format.
        with open(file, 'rb') as f:
            mds_dict = pickle.load(f)
        if 'pdbqt_editor' not in mds_dict:
            self.pdb_editor = None
            self.protein_loaded = False
            return 'Protein not found in mds file', gr.update(interactive=False), display_update

        setting_dict = mds_dict['pdbqt_editor']
        self.pdb_editor = PDBEditor()
        self.pdb_editor.pdb_chain_dict = setting_dict
        if self.pdb_editor.check_format_type() == 'pdbqt':
            # Convert the stored PDBQT records to PDB and re-parse them while
            # carrying over the saved Display/Flexible flags.
            full_str = self.pdb_editor.convert_full_dict_to_text()
            display_flex_dict = {}
            _, string = pdbqt_to_pdb(full_str)
            for chain, df in self.pdb_editor.pdb_chain_dict.items():
                display, flex = df['Display'], df['Flexible']
                display_flex_dict[chain] = {'Display': display.to_list(), 'Flexible': flex.to_list()}
            self.pdb_editor = PDBEditor(string)
            self.pdb_editor.parse_pdb_text_to_dict(display_flex_dict)
        display_strs = []
        for chain in setting_dict:
            r = self.pdb_editor.convert_to_range_text(chain)
            if r:
                display_strs.append(f'{chain}:{r}')
        display_update = gr.update(value=' '.join(display_strs))
        self.protein_loaded = True
        return 'Protein file loaded (mds)', gr.update(interactive=self._can_dock()), display_update

    def upload_ligand_file(self, file):
        """Handle a change of the ligand CSV upload.

        The uploaded CSV is copied to ``dock_input/ligand.csv`` with two extra
        columns: ``protein_path`` (the processed protein file) and an empty
        ``protein_sequence``. ``total_ligands`` is the number of data rows.

        Returns:
            ``(status text, dock-button update)``.
        """
        csv_pth = os.path.join(self.dock_input_dir, 'ligand.csv')
        pdb_pth = os.path.join(self.cache_dir, 'protein_processed.pdb')
        if file:
            with open(file) as f:
                new_csv_str = [l+',protein_path,protein_sequence' if i == 0
                               else l+f',{pdb_pth},'
                               for i, l in enumerate(f.read().strip().splitlines())]
            self.total_ligands = len(new_csv_str) - 1
            with open(csv_pth, 'w') as f:
                f.write('\n'.join(new_csv_str))
            self.ligand_laoded = True
            return 'Ligand file loaded', gr.update(interactive=self._can_dock())
        # The upload may be cleared before anything was ever written.
        if os.path.exists(csv_pth):
            os.remove(csv_pth)
        self.ligand_laoded = False
        return 'Ligand file removed', gr.update(interactive=False)

    def _collectable_results(self):
        """List ``(name, path, files)`` of result directories ready to be collected.

        A directory qualifies when it is not ``cache_files``, has not been
        recorded yet and already contains a ``rank1_*`` file. Plain files and
        directories without that file are left alone.
        """
        ready = []
        for mol_name in os.listdir(self.dock_output_dir):
            if mol_name == 'cache_files' or mol_name.endswith('.sdf') or mol_name in self.recorded_names:
                continue
            subdir = os.path.join(self.dock_output_dir, mol_name)
            if not os.path.isdir(subdir):
                continue
            subdir_files = os.listdir(subdir)
            if any(f.startswith('rank1_') for f in subdir_files):
                ready.append((mol_name, subdir, subdir_files))
        return ready

    def check_curr_docking_status(self):
        """Return ``(status text, dock-button update)`` for the status poller.

        Docking counts as finished when the number of collected ``.sdf`` files
        equals ``total_ligands``, or when the subprocess has exited and no
        result directory is left to collect (so a crashed or partially failed
        run does not leave the UI on "Docking..." forever).
        """
        idle_button = gr.update(interactive=self._can_dock())
        if self.process is None:
            return 'Not docking', idle_button
        if not self.curr_docking:
            return 'Docking done', idle_button
        c = sum(1 for entry in os.listdir(self.dock_output_dir)
                if entry != 'cache_files' and entry.endswith('.sdf'))
        if c == self.total_ligands:
            self.curr_docking = False
            return 'Docking done', idle_button
        if self.process.poll() is not None and not self._collectable_results():
            self.curr_docking = False
            return 'Docking done', idle_button
        return f'Docking... ( {c} / {self.total_ligands} )', gr.update(interactive=False)

    def check_confidence(self):
        """Collect finished per-ligand result directories.

        For every collectable directory: merge all ``*_confidence*.sdf`` poses
        (highest confidence first) into ``<name>_out.sdf`` in the output
        directory, delete the directory, and append ``name,confidence`` of the
        rank-1 pose to the confidence CSV. Does nothing before the first run.
        """
        if self.process is None:
            return
        # NOTE(review): a directory is collected as soon as a rank1_* file shows
        # up; if the docking process is still writing further ranks at that
        # moment they can be missed. Confirm against how inference.py writes.
        retrieved = []
        for mol_name, subdir, subdir_files in self._collectable_results():
            rank1_conf = None
            poses = []
            for f in subdir_files:
                if '_confidence' not in f:
                    continue
                aff = f.rsplit('_confidence', 1)[-1].split('.sdf')[0]
                if rank1_conf is None and f.startswith('rank1_'):
                    found = re.search(confidence_compiled, f)
                    # Fall back to the raw text for values the pattern rejects (e.g. "nan").
                    rank1_conf = found.group(1) if found else aff
                try:
                    score = float(aff)
                except ValueError:
                    continue
                with open(os.path.join(subdir, f)) as sdf_f:
                    mol = Chem.MolFromMolBlock(sdf_f.read())
                if mol is not None:   # RDKit returns None for unparsable blocks
                    poses.append((mol, score))
            if rank1_conf is None:
                continue
            poses.sort(key=lambda pose: pose[1], reverse=True)
            target_sdf_file = os.path.join(self.dock_output_dir, mol_name+'_out.sdf')
            with Chem.SDWriter(target_sdf_file) as writer:
                for mol, aff in poses:
                    mol.SetProp('Score', f'ENERGY= {aff}  LOWER_BOUND= 0.000  UPPER_BOUND= 0.000')
                    writer.write(mol)
            shutil.rmtree(subdir)
            retrieved.append([mol_name, rank1_conf])
            self.recorded_names.append(mol_name)
        if retrieved:
            with open(self.csv_file, 'a') as f:
                for name, conf in retrieved:
                    f.write(f'{name},{conf}\n')

    def _parse_display_string(self, display_res: str):
        """Apply a display string such as ``'A:1-15,20-35 B:15-20'`` to the editor.

        Each whitespace-separated token must have the form ``chain:selection``.
        Chains that are not mentioned keep their current display mask.

        Raises:
            ValueError: if a token does not contain exactly one ``:``.
        """
        for display in display_res.split():
            chain, res = display.split(':')
            self.pdb_editor.update_display(chain, res)

    def start_docking(self, display_text):
        """Write the displayed residues to the protein file and launch DiffDock.

        The subprocess runs ``inference.py`` from ``self.curr_dir`` with
        stdout/stderr redirected to the log file.
        """
        csv = os.path.join(self.dock_input_dir, 'ligand.csv')
        pdb_pth = os.path.join(self.cache_dir, 'protein_processed.pdb')
        self._parse_display_string(display_text)
        pdb_text = self.pdb_editor.convert_dict_to_pdb_text()
        with open(pdb_pth, 'w') as f:
            f.write(pdb_text)
        # NOTE(review): outputs, recorded names and the confidence CSV of an
        # earlier run are not cleared here, so a second run in the same session
        # is counted together with the first. Decide whether to reset them.
        self.curr_docking = True
        with open(self.log_file, 'w') as log_f:
            # cwd is pinned to the directory captured at setup so the relative
            # script/config paths do not depend on the kernel's current directory.
            self.process = subprocess.Popen(['python', 'inference.py',
                                             '--config', 'default_inference_args.yaml',
                                             '--protein_ligand_csv', f'{csv}',
                                             '--out_dir', f'{self.dock_output_dir}'],
                                            stdout=log_f, stderr=log_f,
                                            cwd=self.curr_dir)

    def read_docked_progress(self):
        """Return the current contents of the docking log file."""
        with open(self.log_file) as f:
            return f.read()

    def zip_docked_files(self, check_status, zipped_name):
        """Generator: zip the output directory when the checkbox is ticked.

        Yields ``(file update, checkbox update, progress text)`` tuples; the
        first yield unticks the checkbox and the last one returns the archive
        path.

        Args:
            check_status: State of the "Zip Result" checkbox.
            zipped_name: Requested archive name without extension; only its
                final path component is used so the archive always lands in
                the zip directory.
        """
        if not check_status:
            yield gr.update(), gr.update(), gr.update()
            return
        zipped_name = os.path.basename((zipped_name or '').strip())
        if zipped_name in ('', '.', '..'):
            zipped_name = 'docked_result'
        zipped_file = os.path.join(self.zipped_dir, zipped_name + '.zip')
        all_files = [os.path.join(root, file)
                     for root, dirs, files in os.walk(self.dock_output_dir)
                     for file in files]
        file_cnt = len(all_files)
        width = len(str(file_cnt))
        yield gr.update(), False, f'Progress ({0:{width}}/{file_cnt})'
        with zipfile.ZipFile(zipped_file, 'w', zipfile.ZIP_LZMA) as zipf:
            for num, file in enumerate(all_files, start=1):
                zipf.write(file, os.path.relpath(file, self.dock_output_dir))
                yield gr.update(), gr.update(), f'Progress ({num:{width}}/{file_cnt})'
        yield zipped_file, gr.update(), 'Zipping Done'

    def new_dataframe(self):
        """Return a table update holding every row of the confidence CSV (None when empty)."""
        with open(self.csv_file) as f:
            text = f.readlines()[1:]   # skip the header row
        array = [convert_row_to_string_float(l) for l in text if l.strip()]
        return gr.update(value=array if array else None)


# Build and launch the Gradio app. setup_interface() creates the dock_input/,
# dock_output/ and zipped_files/ directories under the current working directory
# and calls launch(share=True).
interface = DiffDockInterface()
interface.setup_interface()