In [None]:
# build a database from a dataset for nomenclating
# INPUT: a collection of GAIN domain PDBs, their sequences as one large ".fa" file
from gain_classes import GainDomain, GainCollection, Anchors, GPS
import sse_func
import execute
import numpy as np
import glob
import multiprocessing as mp
from subprocess import Popen, PIPE


In [None]:
# The goal is to create a GainCollection from all GAIN domains found here.
# The set of proteins needs to be filtered and analyzed statistically first.

# Example PDB name:
#  R0LVI9-R0LVI9_ANAPL-PutativeGR125-Anas_platyrhynchos_unrelaxed_rank_1_model_3.pdb
# corresponding STRIDE name:
#  R0LVI9-R0LVI9_ANAPL-PutativeGR125-Anas_platyrhynchos.stride
#pdbs = glob.glob("/home/hildilab/projects/agpcr_nom/*/batch*/*_rank_1_*.pdb")
# Collect the best-ranked ("_rank_1_") model of every prediction.
# NOTE(review): glob.glob() is called without recursive=True, so "**" behaves
# like a plain "*" and matches exactly one directory level below "*output".
# Confirm that this is the intended search depth.
pdbs = glob.glob("/home/hildilab/projects/agpcr_nom/*output/**/*_rank_1_*.pdb")
print(f"Found {len(pdbs)} best ranked models in target directories.")
#print(len(celsr_pdbs))

In [None]:
# Folder holding the "<name>.stride" output files, and the path of the STRIDE executable.
stride_folder = f"/home/hildilab/projects/agpcr_nom/all_gps_stride"
stride_bin = "/home/hildilab/lib/stride/stride"
           
def compile_stride_mp_list(pdbs, stride_folder,stride_bin):
    """
    Build one argument item per PDB file for the parallel STRIDE run.

    Parameters
    ----------
    pdbs: iterable of str
        Paths of the PDB files to be processed.
    stride_folder: str
        Folder in which the "<name>.stride" output files are placed.
    stride_bin: str
        Path of the STRIDE executable.

    Returns
    ----------
    list
        One [pdb_file, out_file, stride_bin] list per entry in pdbs, where <name>
        is the PDB file name up to (excluding) the first "_unrelaxed_".
    """
    def _as_job(pdb_path):
        # Strip the directory part, then everything from "_unrelaxed_" onwards.
        model_name = pdb_path.split("/")[-1].split("_unrelaxed_")[0]
        return [pdb_path, f"{stride_folder}/{model_name}.stride", stride_bin]

    return [_as_job(pdb_path) for pdb_path in pdbs]

def run_stride(arg):
    """
    Run STRIDE for a single item produced by compile_stride_mp_list().

    Parameters
    ----------
    arg: sequence of three items
        (pdb_file, out_file, stride_bin)

    Returns
    ----------
    None
    """
    pdb_path, stride_out, binary = arg
    # The command string is handed to the project helper execute.run_command().
    execute.run_command(f"{binary} {pdb_path} -f{stride_out}")

def execute_stride_mp(stride_mp_list, n_threads=10):
    """
    Run STRIDE for every item of stride_mp_list in a pool of worker processes.

    Parameters
    ----------
    stride_mp_list: list
        Items as produced by compile_stride_mp_list(); each one is passed to run_stride().
    n_threads: int, optional
        Number of worker processes in the pool. Default = 10.

    Returns
    ----------
    None
    """
    # The context manager shuts the pool down once map() has returned, so the
    # worker processes do not pile up when this cell is re-run in the same kernel.
    with mp.Pool(n_threads) as stride_pool:
        stride_pool.map(run_stride, stride_mp_list)
    print("Completed mutithreaded creation of STRIDE files!")

In [None]:
#stride_mp_list = compile_stride_mp_list(celsr_pdbs, stride_folder, stride_bin)
#stride_mp_list = compile_stride_mp_list(singlet_pdbs, stride_folder, stride_bin)
#print(len(stride_mp_list))
#[print(arg) for arg in stride_mp_list[:10]]
# MP execution of STRIDE
#execute_stride_mp(stride_mp_list, n_threads=2)

In [None]:
# Eliminate double entries (both in the original run and the added small runs)
# and form the reduced "singlet_pdbs" list.

stride_files = glob.glob("/home/hildilab/projects/agpcr_nom/all_gps_stride/*")
print(len(stride_files))
# The accession is the first "-"-separated field of the file name, e.g.
#  R0LVI9-R0LVI9_ANAPL-PutativeGR125-Anas_platyrhynchos.stride  ->  R0LVI9
accessions = [f.split(".strid")[0].split("/")[-1].split("-")[0] for f in stride_files]
pdb_accessions = np.array([p.split("_unrelaxed_")[0].split("/")[-1].split("-")[0] for p in pdbs])

# Find duplicates in the original pdbs list and indicate them via > is_duplicate = True <
# For every accession that occurs more than once, all occurrences except the LAST one
# are flagged. Flagging only the first occurrence (as before) left two or more copies
# in the reduced list whenever an accession had three or more models.
is_duplicate = np.zeros([len(pdbs)], dtype=bool)
sort_pdb_ac = np.sort(pdb_accessions)
unique_acs, ac_counts = np.unique(pdb_accessions, return_counts=True)
duplicate_list = list(unique_acs[ac_counts > 1])
for ac in duplicate_list:
    multi_indices = np.where(pdb_accessions == ac)[0]
    is_duplicate[multi_indices[:-1]] = True

np_pdbs = np.array(pdbs)
singlet_pdbs = np_pdbs[~is_duplicate] # This is the reduced list with ONLY UNIQUE PDBs
print(f"Reduced the initial set of {len(pdbs)} PDB files down to {len(singlet_pdbs)} files.")

# This is a check routine if there are PDBs in the reduced list which have NOT a STRIDE file
singlet_pdb_accessions = np.array([p.split("_unrelaxed_")[0].split("/")[-1].split("-")[0] for p in singlet_pdbs])

# Set lookup instead of scanning the accession list once per PDB.
known_accessions = set(accessions)
counter = 0
for ac in singlet_pdb_accessions:
    if ac not in known_accessions:
        print(ac)
    else:
        counter += 1
print(f"Found {counter}/{len(singlet_pdb_accessions)} accessions in the accession list.")

In [None]:
"""stride_folder = f"{pdb_folder}/stride_files"
fasta_file = "/home/hildilab/projects/agpcr_nom/batch_files/batch_60.fa"
# CHECK FILES!
import os
if not os.path.isfile(fasta_file) or os.path.islink(fasta_file):
    print("ERROR: Specify FASTA FILE CONTAINING ALL SEQUENCES OF THE PROTEINS!")
    
    
alignment_file = "/home/hildilab/projects/agpcr_nom/batch_out_test/mafft.fa"
#alignment_file = None

# If there is not a specified base alignment, create one (that might take a while tho.)
if not alignment_file:
    mafft_bin = "mafft"
    mafft_command = f"{mafft_bin} --auto --thread 4 {fasta_file}"
    out_dir = "/".join(fasta_file.split("/")[:-2])
    out_file = f"{out_dir}/mafft.fa"
    execute.run_command(mafft_command, out_file = out_file)
    alignment_file = out_file
gps_minus_one_column = 1209"""

In [None]:
# The GainCollection class needs to be modified to also parse a list of sequences 
# instead of a folder containing one fasta per seq
class FilterCollection:
    '''
    A collection of GainDomain objects used for filtering a set of AF2 models.
    This is used to condense the dataset towards one containing only GAIN domains.

    Attributes
    ----------

    collection : np.ndarray(object)
        One slot per input sequence. A slot holds the GainDomain instance if one was
        built and reported isValid, otherwise None.

    valid_gps : np.array(bool)
        For each protein, True if the stored GainDomain reports GPS.isConsensus

    valid_subdomain : np.array(bool)
        For each protein, True if the stored GainDomain reports hasSubdomain

    Methods
    ----------
    print_gps(self):
        Prints info about all GainDomain GPS subinstances

    write_all_seq(self, savename):
        Writes all sequences in the Collection to one file

    transform_alignment(self, input_alignment, output_alignment, aln_cutoff):
        Transforms a given alignment with SSE data to a SSE-assigned version of the alignment

    write_filtered(self, savename, bool_mask):
        Writes all sequences to File where a boolean mask (i.e. subdomain_criterion, gps_criterion)
        is True at respective position
    '''
    def __init__(self,
                alignment_file,
                aln_cutoff,
                quality,
                gps_index,
                stride_files,
                sequence_files=None, # modified to object containing all seqs
                sequences=None, # replaces sequence_files
                subdomain_bracket_size=20,
                domain_threshold=20,
                coil_weight=0.00,
                alignment_dict=None):
        '''
        Constructs the FilterCollection by initializing one GainDomain instance per given sequence

        Parameters
        ----------
        alignment_file:     str, required
            The base dataset alignment file

        aln_cutoff:         int, required
            Integer value of last alignment column

        quality:            list, required
            List of quality valus for each alignment column. Has to have $aln_cutoff items
            By default, would take in the annotated Blosum62 values from the alignment exported from Jalview

        gps_index:          int, required
            Alignment column index of the GPS-1 residue (consensus: Leu)

        stride_files:       list, required
            A list of stride files corresponding to the sequences contained within.

        sequence_files:     list, optional
            A list of sequence files to be read as the collection - specify either this
            or sequences as an object instead of files for sequences

        sequences:          object, optional
            A list of (sequence_name, sequence) tuples containing all sequences. Can be specified
            instead of sequence_files

        subdomain_bracket_size: int, optional
            Smoothing window size for the signal convolve function. Default = 20.

        domain_threshold:   int, optional
            Minimum size of a helical segment to be considered candidate for Subdomain A of GAIN domain. Default = 20.

        coil_weight:        float, optional
            Weight assigned to unordered residues during Subdomain detection. Enables decay of helical signal
            default = 0. Recommended values < +0.2 for decay

        alignment_dict:     object, optional
            Passed through unchanged to every GainDomain. In this notebook it is the
            result of sse_func.read_alignment(alignment_file, aln_cutoff).

        Returns
        ----------
        None
        '''
        if sequence_files:
            # Compile all sequence files to a sequences object.
            # A plain list is used: (name, seq) tuples cannot be stored in the
            # float array that np.empty() creates by default.
            sequences = []
            for seq_file in sequence_files:
                name, seq = sse_func.read_seq(seq_file, return_name=True)
                sequences.append((name, seq))
        elif (sequences is not None):
            print(f"Found sequences object.")
        else:
            print(f"ERROR: no sequence_files or sequences parameter found. Aborting compilation.")
            # Leave an empty but well-formed collection behind so that the
            # methods of this object do not fail on missing attributes.
            self.collection = np.empty([0], dtype=object)
            self.valid_gps = np.zeros([0], dtype=bool)
            self.valid_subdomain = np.zeros([0], dtype=bool)
            return None

        # np.empty with dtype=object starts out filled with None.
        self.collection = np.empty([len(sequences)], dtype=object)

        valid_gps = np.zeros([len(sequences)], dtype=bool)
        valid_subdomain = np.zeros([len(sequences)], dtype=bool)

        # Create a GainDomain instance for each sequence contained in the list
        for i, seq_tup in enumerate(sequences):
            name, sequence = seq_tup
            # NOTE(review): the stride file is found by substring match of the first
            # "-"-separated field of the name anywhere in the path; if that field is
            # contained in another file's path, the first hit in stride_files wins.
            accession = name.split("-")[0]
            explicit_stride = next((stride for stride in stride_files if accession in stride), None)
            if explicit_stride is None:
                print(f"Stride file not found for {name}")
                continue
            newGain = GainDomain(alignment_file,
                                  aln_cutoff,
                                  quality,
                                  name = name,
                                  sequence = sequence,
                                  gps_index = gps_index,
                                  subdomain_bracket_size = subdomain_bracket_size,
                                  domain_threshold = domain_threshold,
                                  coil_weight = coil_weight,
                                  explicit_stride_file = explicit_stride,
                                  without_anchors = True,
                                  skip_naming = True,
                                  alignment_dict = alignment_dict)

            # Check if the object satisfies minimum criteria
            if newGain.isValid:

                self.collection[i] = newGain

                if newGain.hasSubdomain:
                    valid_subdomain[i] = True
                if newGain.GPS.isConsensus:
                    valid_gps[i] = True

        self.valid_subdomain = valid_subdomain
        self.valid_gps = valid_gps
        print(f"Completed Collection initialitazion of {len(sequences)} sequences.\n"
             f"{np.count_nonzero(self.collection)} valid proteins were found.\n"
             f"{np.count_nonzero(self.valid_subdomain)} of which have detected Subdomains.\n"
             f"{np.count_nonzero(self.valid_gps)} of which have detected consensus GPS motifs.\n")

    @staticmethod
    def _fasta_header(gain):
        # Header used for a GainDomain in written files: its name without ".fa"
        # (same rule that write_filtered has always applied).
        return gain.name.replace('.fa','')

    def print_gps(self):
        '''
        Prints information about the GPS of each GAIN domain.

        Parameters:
            None
        Returns:
            None
        '''
        for i, gain in enumerate(self.collection):
            if gain is None:
                # Empty slot: there is no name to report.
                print(f"No GPS data for entry {i}: no valid GainDomain was stored.")
                continue
            try:
                gain.GPS.info()
            except Exception:
                print(f"No GPS data for {gain.name}. Likely not a GAIN Domain!")

    def write_all_seq(self, savename):
        '''
        Write all GAIN sequences of the collection into one fasta file.
        Empty (None) slots of the collection are skipped.

        Parameters
        ----------
        savename: str, required
            Output name of the fasta file.

        Returns
        ----------
        None
        '''
        with open(savename, 'w') as f:
            for gain in self.collection:
                if gain is None:
                    continue
                f.write(f">{self._fasta_header(gain)}\n{''.join(gain.sequence)}\n")

    def write_filtered(self, savename, bool_mask=None, write_mode='w'):
        '''
        Internal function for writing filtered sequences to file.
        Takes the Gain.sequence np.array type to write the truncated versions.

        Parameters
        ----------
        savename: str, required
            Output name of the fasta file.
        bool_mask: list/array, optional
            A mask of len(self.collection) where a boolean defines whether to write the
            sequence to file or not. If None, every stored GainDomain is written.
        write_mode: str, optional
            Mode passed to open(). Default = 'w'.
        Returns
        ----------
        None
        '''
        with open(savename, write_mode) as f:
            print(f"writing filtered to {savename}")
            for i, gain in enumerate(self.collection):
                if gain is None:
                    continue
                if bool_mask is None or bool_mask[i]:
                    f.write(f">{self._fasta_header(gain)}\n{''.join(gain.sequence)}\n")

    def transform_alignment(self, input_alignment, output_alignment, aln_cutoff):
        '''
        Transform any input alignment containing all sequences in the collection
        into one where each residue is replaced with the respective
        Secondary Structure from the STRIDE files.
        Empty (None) slots of the collection are skipped.

        Parameters
        ----------
        input_alignment: str, required
            Input alignment file
        output_alignment: str, required
            Output alignment file
        aln_cutoff: int, required
            Last alignment column to be read from the Input Alignment

        Returns
        ---------
        None
        '''
        out_dict = {}
        for gain in self.collection:
            if gain is None:
                continue
            # Start from an all-gap row and fill in the SSE letter of each residue
            # at the alignment column that sse_func.get_indices() reports for it.
            sse_alignment_row = np.full([aln_cutoff], fill_value='-', dtype='<U1')
            mapper = sse_func.get_indices(gain.name, gain.sequence, input_alignment, aln_cutoff)
            for index, resid in enumerate(gain.sse_sequence):
                sse_alignment_row[mapper[index]] = resid
            out_dict[self._fasta_header(gain)] = sse_alignment_row

        # Write to file
        with open(output_alignment, "w") as f:
            for key, row in out_dict.items():
                f.write(f">{key}\n{''.join(row)}\n")
        print(f"Done transforming alignment {input_alignment} to {output_alignment} with SSE data.")
            

In [None]:
# Input files for the truncated-CELSR set: alignment (.fa), the matching
# quality file (.jal) and the FASTA file holding all sequences.
# The commented-out paths/values belong to the big_mafft alignment.
#alignment_file = "/home/hildilab/projects/agpcr_nom/batch_out_test/big_mafft.fa"
#quality_file = "/home/hildilab/projects/agpcr_nom/batch_out_test/big_mafft.jal"
alignment_file = "/home/hildilab/projects/GPS_massif/uniprot_query/trunc_celsr.mafft.fa"
quality_file = "/home/hildilab/projects/GPS_massif/uniprot_query/trunc_celsr.mafft.jal"
fasta_file = "/home/hildilab/projects/GPS_massif/uniprot_query/all_celsr_trunc.fa"
#stride_folder = f"{pdb_folder}/stride_files"
quality = sse_func.read_quality(quality_file)
# Alignment column of the GPS-1 residue and last alignment column to be read;
# both values are specific to the alignment file chosen above.
gps_minus_one = 4966 #19258
aln_cutoff = 4990 #19822
sequences = sse_func.read_multi_seq(fasta_file)
print(len(sequences))
stride_files = glob.glob("/home/hildilab/projects/agpcr_nom/all_gps_stride/*")
print(len(stride_files))

In [None]:
"""filterCollection = FilterCollection(alignment_file,
                                   aln_cutoff = 19822,
                                   quality = quality,
                                   gps_index = gps_minus_one,
                                   stride_files = stride_files,
                                   sequences = sequences)"""

In [None]:
#filterCollection.write_filtered(savename="test.fa", bool_mask = filterCollection.valid_gps, write_mode='w')

In [None]:
def batch_filter_seqs(arg_item):
    """
    Parallelizable filtering of one batch of sequences/models via FilterCollection.

    Plots a profile for every stored GainDomain (plus a helicality plot if it has a
    subdomain) and writes four FASTA files for the batch, one per criterion.
    The batch size is arbitrary and is the number of sequences passed.

    Parameters
    ----------
    arg_item: sequence of eight items
        [sequences,       # A number of sequences as tuple instances
         stride_folder,   # A folder containing ALL stride files
         output_prefix,   # A prefix for individual file identification
         alignment_file,  # The big (initial) alignment file
         quality,         # The corr. parsed quality for BLOSUM62 score
         aln_cutoff,      # the left-most column (19822 for big_mafft.fa)
         gps_minus_one,   # The column index of GPS-1 (zero-indexed! 19258 big_mafft)
         alignment_dict,  # Passed through to FilterCollection
        ]

    Returns
    ----------
    None
    """
    (sequences, stride_folder, output_prefix, alignment_file,
     quality, aln_cutoff, gps_minus_one, alignment_dict) = arg_item

    # List the stride files of the folder that was passed in, instead of silently
    # relying on a notebook-global "stride_files" that happens to be defined.
    stride_files = glob.glob(f"{stride_folder}/*")

    filteredBatch = FilterCollection(alignment_file,
                                   aln_cutoff = aln_cutoff,
                                   quality = quality,
                                   gps_index = gps_minus_one,
                                   stride_files = stride_files,
                                   sequences = sequences,
                                   alignment_dict = alignment_dict)
    outpath = "/home/hildilab/projects/agpcr_nom/all_gps_profiles_001"
    # NOTE(review): hel_path is never used - the helicality plots below are saved
    # under outpath. Confirm which folder they are meant to go to.
    hel_path = "/home/hildilab/projects/agpcr_nom/all_gps_hels_001"

    for Gain in filteredBatch.collection:
        if Gain:
            Gain.plot_profile(outdir=outpath, noshow=True)
            if Gain.hasSubdomain:
                Gain.plot_helicality(coil_weight=0.01, savename=f'{outpath}/{Gain.name}.hel.png', debug=False, noshow=True)

    has_gps = filteredBatch.valid_gps
    has_subdomain = filteredBatch.valid_subdomain
    is_gain = np.logical_and(has_gps, has_subdomain)
    # suffix -> mask. "invalid" is everything that is not a full GAIN domain, so it
    # also contains the entries of "fragments" and "noncons_gps".
    criteria = [("gain", is_gain),
                ("fragments", np.logical_not(has_subdomain)),
                ("noncons_gps", np.logical_not(has_gps)),
                ("invalid", np.logical_not(is_gain))]
    # write four separate files, matching each criterion
    for suffix, mask in criteria:
        filteredBatch.write_filtered(savename=f"{outpath}/{output_prefix}_{suffix}.fa",
                                     bool_mask = mask,
                                     write_mode = 'w')
    del filteredBatch
    return None

In [None]:
def run_mp_collection(arg_list, n_threads=10):
    """
    Run batch_filter_seqs() on every item of arg_list in a pool of worker processes.

    Parameters
    ----------
    arg_list: list
        Items as produced by construct_arg_list().
    n_threads: int, optional
        Number of worker processes in the pool. Default = 10.

    Returns
    ----------
    None
    """
    # The context manager shuts the pool down once map() has returned, so the
    # worker processes do not pile up when this cell is re-run in the same kernel.
    with mp.Pool(n_threads) as pool:
        pool.map(batch_filter_seqs, arg_list)
    print("Completed mutithreaded filtering.")

def construct_arg_list(batch_sequence_files, 
                       output_folder,
                       stride_folder, 
                       quality, 
                       alignment_file, 
                       aln_cutoff, 
                       gps_minus_one,
                       alignment_dict = None):
    """
    Compile the argument items for the multiprocessed filtering, one per batch file.

    Each item is a list in the order expected by batch_filter_seqs():
        [sequences, stride_folder, output_prefix, alignment_file,
         quality, aln_cutoff, gps_minus_one, alignment_dict]

    Static per item   : stride_folder, quality, alignment_file, aln_cutoff,
                        gps_minus_one, alignment_dict
    Flexible per item : sequences (read from the batch file via
                        sse_func.read_multi_seq) and output_prefix, which is
                        "<output_folder>_<position of the file in the list, zero-padded to 3 digits>"

    Returns
    ----------
    list
        The compiled items, in the order of batch_sequence_files.
    """
    arg_list = [
        [sse_func.read_multi_seq(sequence_file),
         stride_folder,
         f"{output_folder}_{idx:03d}",
         alignment_file,
         quality,
         aln_cutoff,
         gps_minus_one,
         alignment_dict]
        for idx, sequence_file in enumerate(batch_sequence_files)
    ]

    print(f"[NOTE] : Compiled list of arguments for multithreaded filtering"
          f" containing {len(arg_list)} items.")
    return arg_list

In [None]:
# Configuration of the full filtering run over all batch FASTA files.
# NOTE(review): glob.glob() does not guarantee a sorted result, and
# construct_arg_list() numbers the output prefixes by list position - the
# batch <-> index mapping may therefore differ between runs/machines.
batch_sequence_files = glob.glob("/home/hildilab/projects/agpcr_nom/*output/batch_*.fa")
print(len(batch_sequence_files))
# Despite its name this is used as a file-name prefix ("<output_folder>_<index>"),
# not as a directory.
output_folder = "app_gain_domains_001"

alignment_file = "/home/hildilab/projects/agpcr_nom/appended_big_mafft.fa" # This is a combined alignment of ALL sequences in ALL queries!
quality_file = "/home/hildilab/projects/agpcr_nom/appended_big_mafft.jal"  # ^ corresponding quality file
stride_folder = "/home/hildilab/projects/agpcr_nom/all_gps_stride" 
quality = sse_func.read_quality(quality_file)
# Alignment column of the GPS-1 residue and last alignment column to be read;
# the trailing comments keep the values of the earlier big_mafft alignment.
gps_minus_one = 21160  # 19258
aln_cutoff = 21813 # 19822
stride_files = glob.glob("/home/hildilab/projects/agpcr_nom/all_gps_stride/*")
# Parse the alignment once here; it is handed to every batch via the arg items.
alignment_dict = sse_func.read_alignment(alignment_file, aln_cutoff)
print(len(stride_files))
print(len(batch_sequence_files))
#print(quality)
#sequences = sse_func.read_multi_seq(fasta_file)

arg_list = construct_arg_list(batch_sequence_files, 
                       output_folder,
                       stride_folder, 
                       quality, 
                       alignment_file, 
                       aln_cutoff, 
                       gps_minus_one,
                       alignment_dict)


In [None]:
# Serial alternative (useful for debugging a few batches):
#for k in range(0, 10):#len(arg_list)+1):
#    batch_filter_seqs(arg_list[k])
# NOTE(review): only the items from position 400 onwards are processed here.
# Which batches these are depends on the order of arg_list - confirm that the
# first 400 were handled elsewhere, or run the full list.
run_mp_collection(arg_list[400:], n_threads=16)
#for arg in arg_list[-5:]: print(arg[0])

In [None]:

def compile_fastas(prefix, out_prefix):
    """
    Compile the per-batch fasta files into one large file per criterion.

    For each criterion in the 2x2 matrix, all batch files of that criterion are read
    and their sequences (the GAIN sequences output by write_filtered()) are written
    to "<out_prefix>_<criterion>.fa". An identical (name, sequence) pair is written
    only once per output file; repeats are reported as "doublet".

    Parameters
    ----------
    prefix: str, required
        Common path prefix of the batch files; all files matching "<prefix>*fa" are considered.
    out_prefix: str, required
        Path prefix of the four output files.

    Returns
    ----------
    None
    """
    # output label -> suffix carried by the batch files. The batch files of the
    # non-consensus group end in "_noncons_gps.fa", so looking for "noncons" in the
    # last "_"-separated token ("gps.fa") never found them.
    suffixes = {"gain": "gain",
                "fragments": "fragments",
                "noncons": "noncons_gps",
                "invalid": "invalid"}
    all_files = glob.glob(f"{prefix}*fa")
    print(len(all_files))
    for suffix, batch_suffix in suffixes.items():
        # Compare against as many trailing "_"-tokens as the batch suffix itself has.
        n_tokens = batch_suffix.count("_") + 1
        sub_list = sorted(f for f in all_files
                          if batch_suffix in "_".join(f.split("_")[-n_tokens:]))
        print(f"Sublist constructed for {suffix = } containing {len(sub_list)} files.")
        with open(f"{out_prefix}_{suffix}.fa", "w") as all_file:
            # Set lookup instead of scanning a growing list for every sequence.
            seen = set()
            for file in sub_list:
                for j in sse_func.read_multi_seq(file):
                    record = (j[0], j[1])
                    if record in seen:
                        print(j[0], "doublet")
                        continue
                    seen.add(record)
                    all_file.write(f">{j[0]}\n{j[1]}\n")

# Merge the per-batch files found under the given prefix into four
# "app_gain_001_<criterion>.fa" files.
# NOTE(review): every file in all_gps_profiles_001 whose name starts with
# "app_gain_domains" and ends in "fa" is picked up - verify that the folder only
# holds the files of the run that is meant to be merged.
compile_fastas("/home/hildilab/projects/agpcr_nom/all_gps_profiles_001/app_gain_domains",
              out_prefix = "/home/hildilab/projects/agpcr_nom/app_gain_001")

In [None]:
# Get the full profile for D1 : Q6QNK2.
# Compare the 0.01 sequences to the 0.00 sequences:
# print every entry of the 0.00 file (f0) that is absent from the 0.01 file (f1).
f1 = "../app_gain_001_gain.fa"
f0 = "../app_gain_gain.fa"
dict1 = sse_func.read_alignment(f1, -1)
dict0 = sse_func.read_alignment(f0, -1)
missing_in_f1 = [entry for entry in dict0.keys() if entry not in dict1.keys()]
for entry in missing_in_f1:
    print(entry)