# Analyze MicroMiner hits on the whole PDB
Collects all MicroMiner hits for the PDB and stores them in a SQLite database for memory-efficient analysis.



In [1]:
import csv
import sqlite3
import sys
from collections import OrderedDict
from pathlib import Path


root_dir = Path('/local/sieg/projekte/microminer_evaluation')
sys.path.insert(0, str(root_dir.resolve()))
import helper
from helper.constants import one_2_three_dict, three_2_one_dict, MM_QUERY_NAME, MM_QUERY_CHAIN, MM_QUERY_AA, MM_QUERY_POS, \
                             MM_HIT_NAME, MM_HIT_CHAIN, MM_HIT_AA, MM_HIT_POS

In [2]:
# Monomer mode: MicroMiner hit table (tab-separated) and the SQLite database it is loaded into.
csv_monomer = '/local/sieg/pdb_all_monomer.tsv'
db_name_monomer = '/local/sieg/mm_pdb_single_mutations_monomer.db'

# PPI mode: MicroMiner hit table (tab-separated) and the SQLite database it is loaded into.
csv_ppi = '/local/sieg/pdb_all_ppi.tsv'
db_name_ppi = '/local/sieg/mm_pdb_single_mutations_ppi.db'

In [3]:
class CSVWriterToDatabase:
    """Converts a MicroMiner output CSV file to a single table in a database.

    On construction the database connection is opened and the
    ``result_statistics`` table is created if it does not exist yet.
    ``write_csv_to_database`` then loads one CSV/TSV file into that table.
    Only ``db_type="sqlite"`` is implemented.
    """

    # Column name -> SQL column type of the result table. The same names are
    # looked up in the header of the input file.
    RESULT_STATISTICS_FIELDS = OrderedDict([(MM_QUERY_NAME, 'TEXT'),
                                            (MM_QUERY_CHAIN, 'TEXT'),
                                            (MM_QUERY_AA, 'TEXT'),
                                            (MM_QUERY_POS, 'TEXT'),
                                            (MM_HIT_NAME, 'TEXT'),
                                            (MM_HIT_CHAIN, 'TEXT'),
                                            (MM_HIT_AA, 'TEXT'),
                                            (MM_HIT_POS, 'TEXT'),
                                            ('siteIdentity', 'REAL'),
                                            ('siteBackBoneRMSD', 'REAL'),
                                            ('siteAllAtomRMSD', 'REAL'),
                                            ('nofSiteResidues', 'REAL'),
                                            ('alignmentLDDT', 'REAL'),
                                            ('fullSeqId', 'REAL')])

    # Each hit is defined uniquely by the tuple of name, chain, amino acid type and position of
    # query and hit structure.
    RESULT_STATISTICS_PRIMARY_FIELDS = (MM_QUERY_NAME, MM_QUERY_CHAIN, MM_QUERY_AA, MM_QUERY_POS,
                                        MM_HIT_NAME, MM_HIT_CHAIN, MM_HIT_AA, MM_HIT_POS)

    # Inverse of the primary key (query and hit swapped) to identify symmetric matches.
    RESULT_STATISTICS_INVERSE_KEY = [MM_HIT_NAME, MM_HIT_CHAIN, MM_HIT_AA, MM_HIT_POS,
                                     MM_QUERY_NAME, MM_QUERY_CHAIN, MM_QUERY_AA, MM_QUERY_POS]

    def __init__(self, db_type, db_name, ignore_nonstandard_aa, delimiter=','):
        """Open the database and make sure the result table exists.

        Args:
            db_type: Database backend; only ``"sqlite"`` is implemented.
            db_name: Database to connect to (for SQLite: the database file).
            ignore_nonstandard_aa: If true, rows whose query or hit amino acid
                is not a key of ``three_2_one_dict`` are skipped on import.
            delimiter: Field delimiter of the input files.

        Raises:
            NotImplementedError: For ``db_type="postgres"``.
            ValueError: For any other unknown ``db_type``.
        """
        self.db_type = db_type
        self.db_name = db_name
        self.delimiter = delimiter
        self.ignore_nonstandard_aa = ignore_nonstandard_aa
        self.table_name = 'result_statistics'

        # Defined before connecting so that __del__ stays safe when one of the
        # branches below raises.
        self.conn = None
        self.cur = None

        if self.db_type == "sqlite":
            self.conn = sqlite3.connect(self.db_name)
            self.cur = self.conn.cursor()
        elif self.db_type == "postgres":
            # PostgreSQL support is not implemented yet.
            raise NotImplementedError()
        else:
            raise ValueError("Invalid database type")

        # init table
        self.create_resultstatistics_table()

    def __del__(self):
        # The cursor is closed before its connection. Attributes may be missing
        # or None if __init__ did not get as far as connecting.
        for resource in (getattr(self, 'cur', None), getattr(self, 'conn', None)):
            if resource is not None:
                resource.close()

    def create_resultstatistics_table(self):
        """Create the result table (if missing) with the composite primary key."""
        columns = ', '.join(f"{name} {sql_type}" for name, sql_type in
                            CSVWriterToDatabase.RESULT_STATISTICS_FIELDS.items())
        primary_key = ', '.join(CSVWriterToDatabase.RESULT_STATISTICS_PRIMARY_FIELDS)
        self.cur.execute(f'''CREATE TABLE IF NOT EXISTS {self.table_name}
                           ({columns},
                             PRIMARY KEY ({primary_key})
                           )''')
        self.conn.commit()

    def write_csv_to_database(self, file_path):
        """Insert the rows of a MicroMiner result file into the result table.

        Exact duplicates (same primary key) and rows whose inverse (query and
        hit swapped) is already stored are skipped. If
        ``ignore_nonstandard_aa`` is set, rows with a non-standard query or hit
        amino acid are skipped as well and their number is printed.

        Args:
            file_path: Path of the delimited text file; its first row must be a
                header that contains all ``RESULT_STATISTICS_FIELDS`` names.

        Raises:
            ValueError: If the file is empty or lacks a required column.
        """
        fields = list(CSVWriterToDatabase.RESULT_STATISTICS_FIELDS)

        # newline='' lets the csv module handle line endings itself.
        with open(file_path, "r", newline='') as file:
            reader = csv.reader(file, delimiter=self.delimiter)

            header = next(reader, None)
            if header is None:
                raise ValueError(f"{file_path} is empty, expected a header row")

            missing = [col for col in fields if col not in header]
            if missing:
                raise ValueError(f"{file_path} lacks required column(s): {', '.join(missing)}")

            col_indices_dict = OrderedDict((col, header.index(col)) for col in fields)

            # Column positions are fixed for the whole file, so resolve them once:
            # first the values to insert, then the values of the inverse key that
            # feed the NOT EXISTS sub-query.
            value_indices = list(col_indices_dict.values())
            value_indices.extend(col_indices_dict[field]
                                 for field in CSVWriterToDatabase.RESULT_STATISTICS_INVERSE_KEY)
            aa_indices = [col_indices_dict[MM_QUERY_AA], col_indices_dict[MM_HIT_AA]]

            placeholder = '?'

            # Insert a row into the database. OR IGNORE drops exact duplicates
            # (primary key conflict) instead of raising IntegrityError, and the
            # SELECT-WHERE statement drops rows whose inverse is already present.
            query = f"""INSERT OR IGNORE INTO {self.table_name} ({', '.join(fields)})
                            SELECT {','.join([placeholder] * len(fields))} WHERE NOT EXISTS (
                            SELECT 1 FROM {self.table_name}
                            WHERE {' AND '.join(f'{field} = {placeholder}'
                                                for field in CSVWriterToDatabase.RESULT_STATISTICS_PRIMARY_FIELDS)})"""

            counter_skipped_nonstandard_aa = 0
            for row in reader:
                if not row:
                    # csv.reader yields an empty list for blank lines
                    continue

                if self.ignore_nonstandard_aa and any(row[i] not in three_2_one_dict for i in aa_indices):
                    counter_skipped_nonstandard_aa += 1
                    continue

                self.cur.execute(query, [row[i] for i in value_indices])

            print(f'skipped {counter_skipped_nonstandard_aa} non-standard AA hit rows')

        self.conn.commit()
        

class DatabaseReader:
    """Reads from a database representing a MicroMiner results table.

    Only ``db_type="sqlite"`` is implemented. Column names passed to the
    statistics helpers are inserted into the SQL text verbatim, so they must be
    trusted identifiers.
    """

    def __init__(self, db_type, db_name):
        """Connect to the database.

        Args:
            db_type: Database backend; only ``"sqlite"`` is implemented.
            db_name: Database to connect to (for SQLite: the database file).

        Raises:
            NotImplementedError: For ``db_type="postgres"``.
            ValueError: For any other unknown ``db_type``.
        """
        self.db_type = db_type
        self.db_name = db_name
        self.nof_rows = None  # cached result of get_nof_rows()
        self.table_name = 'result_statistics'

        # Defined before connecting so that close()/__del__ are safe when one of
        # the branches below raises.
        self.conn = None
        self.cur = None

        if self.db_type == "sqlite":
            self.conn = sqlite3.connect(self.db_name)
            self.cur = self.conn.cursor()
        elif self.db_type == "postgres":
            # PostgreSQL support is not implemented yet.
            raise NotImplementedError()
        else:
            raise ValueError("Invalid database type")

    def close(self):
        """Close cursor and connection; calling it more than once is harmless."""
        cur = getattr(self, 'cur', None)
        if cur is not None:
            cur.close()
            self.cur = None
        conn = getattr(self, 'conn', None)
        if conn is not None:
            conn.close()
            self.conn = None

    def __del__(self):
        self.close()

    def query(self, sql_str, batch_size=1000):
        """Execute ``sql_str`` and lazily yield the result in batches.

        Args:
            sql_str: SQL statement to execute.
            batch_size: Maximum number of rows per yielded batch.

        Yields:
            Non-empty lists of result rows.
        """
        # A dedicated cursor keeps the iteration intact when other methods of
        # this object (which use self.cur) are called while batches are consumed.
        cursor = self.conn.cursor()
        try:
            cursor.execute(sql_str)
            while True:
                rows = cursor.fetchmany(batch_size)
                if not rows:
                    break
                yield rows
        finally:
            cursor.close()

    def get_header(self):
        """Return the column names of the results table in table order."""
        self.cur.execute("SELECT name FROM pragma_table_info(?) ORDER BY cid", (self.table_name,))
        return [elem[0] for elem in self.cur.fetchall()]

    def get_nof_rows(self):
        """Return the number of rows of the results table.

        The count is computed on first use and cached afterwards.
        """
        if self.nof_rows is None:
            self.nof_rows = self.cur.execute(f"SELECT COUNT(*) FROM {self.table_name}").fetchone()[0]
        return self.nof_rows

    def get_mean(self, col):
        """Return the average of column ``col`` (``None`` if there are no values)."""
        self.cur.execute(f"SELECT AVG({col}) FROM {self.table_name}")
        avg_value = self.cur.fetchone()[0]
        return avg_value

    def get_median(self, col):
        """Return the median of column ``col`` (``None`` if there are no values).

        NULL entries are ignored, as they are by ``get_mean``. For an even number
        of values the two middle values are averaged.
        """
        self.cur.execute(f'''SELECT AVG({col})
                             FROM (SELECT {col}
                                   FROM {self.table_name}
                                   WHERE {col} IS NOT NULL
                                   ORDER BY {col}
                                   LIMIT 2 - (SELECT COUNT({col}) FROM {self.table_name}) % 2    -- odd 1, even 2
                                   OFFSET (SELECT (COUNT({col}) - 1) / 2
                                           FROM {self.table_name}))''')
        median_value = self.cur.fetchone()[0]
        return median_value
              

In [4]:
def sanity_check_bidirectional_hits(reader):
    """Verify that the results table holds no symmetric duplicates.

    A symmetric duplicate is a pair of rows where the query residue of one row
    is the hit residue of the other row and vice versa.

    Args:
        reader: Object whose ``cur`` attribute is a cursor on the database
            containing the ``result_statistics`` table.

    Raises:
        AssertionError: If at least one symmetric pair exists.
    """
    query_cols = ('queryName', 'queryChain', 'queryAA', 'queryPos')
    hit_cols = ('hitName', 'hitChain', 'hitAA', 'hitPos')

    # Row b mirrors row a when a's query columns equal b's hit columns and
    # a's hit columns equal b's query columns.
    mirrored = ' AND '.join(f'a.{left} = b.{right}'
                            for left, right in zip(query_cols + hit_cols, hit_cols + query_cols))
    selected = ', '.join(f'a.{col}' for col in query_cols + hit_cols)

    # One offending row is enough to fail the check.
    reader.cur.execute(f'SELECT {selected} FROM result_statistics AS a, result_statistics AS b '
                       f'WHERE {mirrored} LIMIT 1')
    offending = reader.cur.fetchone()

    # Raised explicitly (instead of using an assert statement) so that the
    # check is not skipped when Python runs with -O.
    if offending is not None:
        raise AssertionError(f'result_statistics contains a symmetric duplicate for hit {offending}')
    
def count_non_standard_aa(reader, standard_aas=None):
    """Count the hits whose query or hit amino acid is not a standard one.

    Args:
        reader: Object whose ``cur`` attribute is a cursor on the database
            containing the ``result_statistics`` table.
        standard_aas: Iterable of amino acid names regarded as standard.
            Defaults to the keys of ``three_2_one_dict``.

    Returns:
        The fetched result row, i.e. a 1-tuple holding the count.
    """
    if standard_aas is None:
        standard_aas = three_2_one_dict.keys()
    standard_aas = list(standard_aas)

    # The names are passed as bound parameters instead of being pasted into the
    # SQL text as double-quoted tokens, which SQLite interprets as identifiers
    # before falling back to string literals.
    placeholders = ', '.join('?' * len(standard_aas))
    query = (f"SELECT COUNT(*) FROM result_statistics "
             f"WHERE queryAA NOT IN ({placeholders}) OR hitAA NOT IN ({placeholders})")

    # The name list is needed twice: once per NOT IN clause.
    reader.cur.execute(query, standard_aas * 2)
    return reader.cur.fetchone()

### Handle monomer mode results

In [5]:
# Load the monomer-mode hit table into its SQLite database; rows with a
# non-standard query or hit amino acid are skipped during the import.
writer = CSVWriterToDatabase(
    db_type="sqlite",
    db_name=db_name_monomer,
    ignore_nonstandard_aa=True,
    delimiter='\t',
)
writer.write_csv_to_database(csv_monomer)

skipped 350968 non-standard AA hit rows


In [6]:
# Open the monomer database for reading and validate its content: there must be
# no symmetric duplicate hits and no hits involving non-standard amino acids.
reader_monomer = DatabaseReader(db_name=db_name_monomer, db_type="sqlite")
sanity_check_bidirectional_hits(reader_monomer)
assert count_non_standard_aa(reader_monomer)[0] == 0

In [7]:
# Filter results down: keep only hits with fullSeqId > 0.4 and reduce every group of
# (queryName, queryAA, queryChain, queryPos, hitAA) to a single row. Because of the
# MAX() aggregate, SQLite takes the other selected columns from the row with the highest
# fullSeqId of the group.
query = 'SELECT *, MAX(fullSeqId) from result_statistics WHERE fullSeqId > 0.4 GROUP BY queryName, queryAA, queryChain, queryPos, hitAA'
# newline='' lets the csv module control the line endings itself
with open('/local/sieg/filtered_single_mutations_pdb_monomer.tsv', 'w', newline='') as csvfile:
    csv_writer = csv.writer(csvfile, delimiter='\t')
    # The query returns all table columns plus the MAX(fullSeqId) aggregate, so the
    # header needs one extra name to stay aligned with the data rows.
    csv_writer.writerow(reader_monomer.get_header() + ['MAX(fullSeqId)'])
    for rows in reader_monomer.query(query):
        csv_writer.writerows(rows)

### Handle PPI mode results

In [8]:
# Load the PPI-mode hit table into its SQLite database; rows with a
# non-standard query or hit amino acid are skipped during the import.
writer = CSVWriterToDatabase(
    db_type="sqlite",
    db_name=db_name_ppi,
    ignore_nonstandard_aa=True,
    delimiter='\t',
)
writer.write_csv_to_database(csv_ppi)

skipped 114670 non-standard AA hit rows


In [9]:
# Open the PPI database for reading and validate its content: there must be
# no symmetric duplicate hits and no hits involving non-standard amino acids.
reader_ppi = DatabaseReader(db_name=db_name_ppi, db_type="sqlite")
sanity_check_bidirectional_hits(reader_ppi)
assert count_non_standard_aa(reader_ppi)[0] == 0

In [10]:
# Filter results down: keep only hits with fullSeqId > 0.4 and reduce every group of
# (queryName, queryAA, queryChain, queryPos, hitAA) to a single row. Because of the
# MAX() aggregate, SQLite takes the other selected columns from the row with the highest
# fullSeqId of the group.
query = 'SELECT *, MAX(fullSeqId) from result_statistics WHERE fullSeqId > 0.4 GROUP BY queryName, queryAA, queryChain, queryPos, hitAA'
# newline='' lets the csv module control the line endings itself
with open('/local/sieg/filtered_single_mutations_pdb_ppi.tsv', 'w', newline='') as csvfile:
    csv_writer = csv.writer(csvfile, delimiter='\t')
    # The query returns all table columns plus the MAX(fullSeqId) aggregate, so the
    # header needs one extra name to stay aligned with the data rows.
    csv_writer.writerow(reader_ppi.get_header() + ['MAX(fullSeqId)'])
    for rows in reader_ppi.query(query):
        csv_writer.writerows(rows)

### Handle mutations to non-standard residues

In [11]:
def extract_mutations_non_standard_aas(file_path):
    """Collect the rows of a tab-separated hit file that involve a non-standard amino acid.

    A row is selected when its query or its hit amino acid is not a key of
    ``three_2_one_dict``.

    Args:
        file_path: Path of the tab-separated file; the first row is the header.

    Returns:
        Tuple ``(header, rows)`` with the header row and the list of selected rows.
    """
    with open(file_path, "r") as tsv_file:
        tsv_reader = csv.reader(tsv_file, delimiter='\t')
        header = next(tsv_reader)
        # positions of the two amino acid columns within each row
        aa_columns = [header.index(name) for name in (MM_QUERY_AA, MM_HIT_AA)]
        result_rows = [
            record for record in tsv_reader
            if not all(record[idx] in three_2_one_dict for idx in aa_columns)
        ]
    return header, result_rows

# Write all monomer-mode hits that involve a non-standard residue to their own TSV file.
with open('/local/sieg/single_mutations_pdb_monomer_non_standard_aa.tsv', 'w') as csvfile:
    csv_writer = csv.writer(csvfile, delimiter='\t')
    header, rows = extract_mutations_non_standard_aas(csv_monomer)
    # header first, then the data rows
    csv_writer.writerows([header, *rows])

# Same export for the PPI-mode hits.
with open('/local/sieg/single_mutations_pdb_ppi_non_standard_aa.tsv', 'w') as csvfile:
    csv_writer = csv.writer(csvfile, delimiter='\t')
    header, rows = extract_mutations_non_standard_aas(csv_ppi)
    # header first, then the data rows
    csv_writer.writerows([header, *rows])