In [25]:
import numpy as np
from collections import defaultdict
from collections import OrderedDict
from Bio import SeqIO
from Bio.Seq import Seq
from Bio import SeqRecord
import random
from scipy.spatial.distance import pdist

FILE_PATH = "1uaz.pdb"
IORF_PATH = "Scer_NCBI_iORF.faa"
random.seed(25032024)

class SuperOD(OrderedDict):
    def __getitem__(self, key):
        if isinstance(key, slice):
            
            keys = list(self.keys())[key]
            return SuperOD((k, self[k]) for k in keys)
        else:
            
            return OrderedDict.__getitem__(self, key)




def compute_pairwise_distances(pdb_struct):
    # Extract coordinates from pdb_struct
    coords = []
    for chain in pdb_struct["full"].values():
        for residue in chain.values():
            coords.append(residue["coord"])

    # Return mean pairwise distance
    return np.mean(pdist(coords))

def size_picker_v2(fasta_file = IORF_PATH, min_length = 0, max_length = 1000):

    sizes = []

    for record in SeqIO.parse(fasta_file, "fasta"):
        length = len(record.seq)
        if min_length <= length <= max_length:

            sizes.append(length)
    
    return random.choice(sizes)


def pdb_struct_to_fasta(pdb_struct):

    aa_dict = {
                    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E',
                    'PHE': 'F', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
                    'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N',
                    'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S',
                    'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'
                }
    fasta = ""

    for res_number in pdb_struct["CA"]:

        res_name = pdb_struct["CA"][res_number]["res_name"]

        fasta += aa_dict[res_name]

    record = SeqRecord.SeqRecord(Seq(fasta), id="1uaz", description="1uaz")
    filename = f"{pdb_struct['protein_name']}.fasta"
    SeqIO.write(record, filename, "fasta")

    return 0

def read_pdb(file_path):
    
    aa_dict = {
                    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E',
                    'PHE': 'F', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
                    'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N',
                    'PRO': 'P', 'GLN': 'Q', 'ARG': 'R', 'SER': 'S',
                    'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'
                }

    pdb_struct = {}

    pdb_struct["protein_name"] = file_path.split(".")[0]

    pdb_struct["full"] = defaultdict(SuperOD)
    pdb_struct["CA"] = defaultdict(SuperOD)
    pdb_struct["membrane_coord"] = []

    array = []

    with open(file_path,"r") as f:

        line = f.readline()

        while line:

            line = line.split()

            if line[0] == "ATOM":

                # Line format : 
                # ATOM      2  CA  MET A   1      24.767  -2.102  13.513  1.00  0.00      A1A9 C  

                x = float(line[6])
                y = float(line[7])
                z = float(line[8])

                atom_name = line[2]
                atom_number = line[1]
                res_name = line[3]
                res_number = line[5]
                
                chain_id = line[4]

                if chain_id not in pdb_struct["full"]:

                    pdb_struct["full"][chain_id] = SuperOD()
                    pdb_struct["CA"][chain_id] = SuperOD()


                else:

                    pdb_struct["full"][chain_id][res_number] = {

                        "coord" : [x,y,z],
                        "atom_name" : atom_name,
                        "res_name" : aa_dict[res_name],
                        "res_number" : res_number,
                        "atom_number" : atom_number
                    }

                
                    if line[2] == "CA":


                        pdb_struct["CA"][chain_id][res_number] = {

                            "coord" : [x,y,z],
                            "res_name" : aa_dict[res_name],
                            "res_number" : res_number,
                        }

            elif line[0] == "HETATM" and "DUM" in line:

                # Line format :
                # HETATM  643  O   DUM   643     -24.000  -6.000  14.200   
            
                x = float(line[5])
                y = float(line[6])    
                z = float(line[7])

                array.append([x,y,z])

            line = f.readline()

    pdb_struct["membrane_coord"] = np.array(array)
    pdb_struct["protein_length"] = { }

    # For each chain, compute the length.
    # The length is the highest value of the residue numbers
    for chain_id in pdb_struct["CA"]:
       
        pdb_struct["protein_length"][chain_id] = max([int(res_number) for res_number in pdb_struct["CA"][chain_id].keys()])
        print(f"Chain {chain_id} has length {pdb_struct['protein_length'][chain_id]}")

    return pdb_struct

def return_binaries(pdb_struct: dict, lower_margin=0, margin=5):
    
    min_z_membrane = np.min(pdb_struct["membrane_coord"][:, 2])
    max_z_membrane = np.max(pdb_struct["membrane_coord"][:, 2])

    in_membrane_binaries = { chain_id: "" for chain_id in pdb_struct["CA"].keys() }
    in_margin_binaries = { chain_id: "" for chain_id in pdb_struct["CA"].keys() }

    # Assuming that the residues start at 1
    last_res_number = 0
    for chain_id, residues in pdb_struct["CA"].items():
        for res_number, data in residues.items():

            z = data["coord"][2]

            # Sometimes, residues are missing from the 3D structure and we don't have their coordinates
            # We can monitor this by checking if the current residue number is not the previous residue number + 1
            # If some residues are missing, fill the binary string with zeros corresponding to the number of missing residues
            # By default we assume that the missing residues are not of interest and are not in the membrane

            if int(res_number) != last_res_number + 1:

                # Fill in the gaps with zeros
                for i in range(int(res_number) - last_res_number - 1):
                    in_membrane_binaries[chain_id] += "0"
                    in_margin_binaries[chain_id] += "0"
                
                # Then compute the binary string for the current residue
                in_membrane = "1" if min_z_membrane <= z <= max_z_membrane + lower_margin else "0"
                in_margin = "1" if min_z_membrane - margin <= z <= max_z_membrane + margin else "0"

            else:

                in_membrane = "1" if min_z_membrane <= z <= max_z_membrane + lower_margin else "0"
                in_margin = "1" if min_z_membrane - margin <= z <= max_z_membrane + margin else "0"

            in_membrane_binaries[chain_id] += in_membrane
            in_margin_binaries[chain_id] += in_margin

            last_res_number = int(res_number)

    # Now before returning, we can assume that the sequence 101 

    return in_membrane_binaries, in_margin_binaries

def extract_tm_segments_indices(binary_dict : dict):

    segment_indices = {chain_id : [] for chain_id in binary_dict.keys() }

    start_index = None

    for chain_id in binary_dict:
        for i, bit in enumerate(binary_dict[chain_id]):
            if bit == "1":
                if start_index is None:
                    start_index = i+1
            else:
                if start_index is not None:
                    length = i-start_index+1
                    if length >= 15: # minimum length of a TM segment, although 20 is the length of a typical alpha helical TM segment
                        
                        # python indices are 0-based, so we add 1 to match 
                        # the 1-based residue numbering in the PDB file
                        segment_indices[chain_id].append((start_index, i+1, length))

                    # Wether the segment is long enough or not, we reset the start_index
                    start_index = None

        # If we ended in the middle of a TM segment, we add it to the list
        if start_index is not None:
            length = len(binary_dict[chain_id]) - start_index + 1
            if length >= 15:
                segment_indices[chain_id].append((start_index, len(binary_dict[chain_id]), length))

    return segment_indices

def elongate_tm_segments(tm_indices : dict, pdb_struct : dict, min_length=20, max_length=70):
    """
    This function takes a list of tuples containing the start and end indices of putative transmembrane (tm) segments
    Extracted for the same multiple-fragments transmembrane protein.
    For example, GPCR proteins have 7 transmembrane segments, they will end up in a list of 7 tuples.

    For each tm segment, the function will elongate the segment to a random size drawn from a given size distribution,
    given by the size_picker_v2 function.

    The function will elongate the segment only if the size of the segment is smaller than the size drawn from the distribution.
    The goal here is to "draw" from the parts of the sequence that are not transmembrane segments, and elongate the tm segments.
    One main goal is to avoid drawing twice from the same region to elongate two tm segments that are adjacent to each other.

    Input:

    tm_indices[chain_id] : list of tuples 
                # [ (12,26,15), (45, 60, 16), (80, 100, 21) ...]
                # [ (start, end, length), ... ]

    min_length : int
                # minimum length of the elongated segment

    max_length : int
                # maximum length of the elongated segment
    """

    for chain_id in tm_indices:

        
        protein_length = pdb_struct["protein_length"][chain_id]

        ##### Treat first TM Segment separately ##### 


        desired_length = size_picker_v2(min_length=min_length, max_length=max_length)

        
        # First TM Segment
        start_current = tm_indices[chain_id][0][0]
        end_current = tm_indices[chain_id][0][1]
        length_current = tm_indices[chain_id][0][2]


        if desired_length > length_current:


            # Second TM Segment
            start_next = tm_indices[chain_id][1][0]

            elongation_left_to_do = desired_length - length_current


            downstream = random.randint(0, elongation_left_to_do)

            
            lefts = None

            # The new end of this tm segment should not exceed the start of the next tm segment
            if downstream + end_current > start_next:

                new_end_coordinates = start_next - 1

                lefts = downstream - ( start_next - end_current )



            else:

                new_end_coordinates = end_current + downstream


            upstream = elongation_left_to_do - downstream

            

            if lefts:

                upstream += lefts



            if start_current - upstream < 1:

                new_start_coordinates = 1



            else:

                new_start_coordinates = start_current - upstream

            tm_indices[chain_id][0] = (new_start_coordinates, new_end_coordinates, new_end_coordinates - new_start_coordinates)


        ##### Treat from the second TM Segment to the penultimate one ( n-1 ) #####

        for i in range(1, len(tm_indices[chain_id]) - 1):

            # Target size that the current tm should reach
            desired_length = size_picker_v2(min_length=min_length, max_length=max_length)

            # ith TM Segment
            start_current = tm_indices[chain_id][i][0]
            end_current = tm_indices[chain_id][i][1]
            length_current = tm_indices[chain_id][i][2]

            # check before anything else to save computation time
            if desired_length <= length_current:

                # If there is no elongation to do, we skip to the next segment
                # and the coordinates of the ith segment are not modified
                continue
            
            # (i+1)th TM Segment
            start_next = tm_indices[chain_id][i+1][0]


            # (i-1)th TM Segment
            end_previous = tm_indices[chain_id][i-1][1]
            
            # Compute the number of residues that are required to elongate the current segment
            elongation_left_to_do = desired_length - length_current


            # Randomly choose the number of residues to elongate downstream ( toward the C-terminal )
            downstream = random.randint(0, elongation_left_to_do)

            lefts = None

            # The new end of this tm segment should not exceed the start of the next tm segment
            if downstream + end_current > start_next:

                # Hence take everyting that is between the end of the current tm segment and the start of the next one
                new_end_coordinates = start_next - 1

                # What is " left " from downstream that could not be taken cause of the next tm ? 
                lefts = downstream - (start_next - end_current)

            else:

                new_end_coordinates = end_current + downstream

            ## If there is elongation that was not taken from downstream, add it to the upstream
            upstream = elongation_left_to_do - downstream
            if lefts:

                upstream += lefts


            # The new start of this tm segment should not be lower than the end of the previous tm segment
            if start_current - upstream < end_previous:

                new_start_coordinates = end_previous + 1 

            else:

                new_start_coordinates = start_current - upstream


            tm_indices[chain_id][i] = (new_start_coordinates, new_end_coordinates, new_end_coordinates - new_start_coordinates)

            


        ##### Treat the last TM Segment #####

        # Target size that the current tm should reach
        desired_length = size_picker_v2(min_length=min_length, max_length=max_length)


        # Last TM Segment
        start_current = tm_indices[chain_id][-1][0]
        end_current = tm_indices[chain_id][-1][1]
        length_current = tm_indices[chain_id][-1][2]

        # check before anything else to save computation time
        if desired_length <= length_current:

            # If there is no elongation to do, we skip to the next segment
            # and the coordinates of the ith segment are not modified
            return 0

        # (i-1)th TM Segment
        end_previous = tm_indices[chain_id][-2][1]

        # Compute the number of residues that are required to elongate the current segment
        elongation_left_to_do = desired_length - length_current



        # Randomly choose the number of residues to elongate downstream ( toward the C-terminal )
        downstream = random.randint(0, elongation_left_to_do)

        lefts = None

        # The new end of this final tm should not exceed the protein length
        if downstream + end_current > protein_length:

            # Hence take everyting that is between the end of the current tm segment and the start of the next tm segment
            new_end_coordinates = protein_length

            # What is " left " from downstream that could not be taken because the protein is too short after the last tm ? 
            lefts = downstream - (protein_length - end_current)


        else:

            new_end_coordinates = end_current + downstream


        upstream = elongation_left_to_do - downstream
        if lefts:

            upstream += lefts


        # The new start of this tm segment should not be lower than the end of the previous tm segment
        if start_current - upstream < end_previous:

            new_start_coordinates = end_previous + 1 

        else:

            new_start_coordinates = start_current - upstream     


        tm_indices[chain_id][-1] =(new_start_coordinates, new_end_coordinates, new_end_coordinates - new_start_coordinates + 1)

    return 0

# Extract the sequence 

In [26]:
def segments_to_string(segments: dict, pdb_struct: dict):

    strings = {chain_id: "" for chain_id in segments.keys()}

    for chain_id, segments in segments.items():

        protein_length = pdb_struct["protein_length"][chain_id]

        # Initialize a list of zeros of length equal to the protein's length
        string_repr = ['0'] * protein_length

        # For each segment, replace the corresponding positions in the string with the segment's position
        for i, (start, end, _) in enumerate(segments):
            string_repr[start-1:end] = [str(i+1)] * (end - start + 1)

        # Join the list into a string
        string_repr = ''.join(string_repr)

        strings[chain_id] = string_repr

    return strings

def combine_strings_v2(segments_strings: dict, in_margin_binaries: dict):

    combined_strings = {chain_id: "" for chain_id in segments_strings.keys()}

    for chain_id in segments_strings:

        segments_string = segments_strings[chain_id]
        in_margin_binary = in_margin_binaries[chain_id]

        # Convert the strings to lists of integers
        segments_list = list(map(int, segments_string))
        in_margin_list = list(map(int, in_margin_binary))

        # Add the corresponding elements of the two lists if both are > 0, and put "X" otherwise
        combined_list = [str(a + b) if a > 0 and b > 0 else 'X' for a, b in zip(segments_list, in_margin_list)]

        # Convert the list back to a string
        combined_string = ''.join(combined_list)

        combined_strings[chain_id] = combined_string

    return combined_strings


def extract_coordinates_v3(combined_strings: dict):

    coordinates = {chain_id: [] for chain_id in combined_strings.keys()}
    for chain_id in combined_strings:

        combined_string = combined_strings[chain_id]
        print(combined_string)
        num_segments =  max(int(char) for char in combined_string if char.isdigit()) - 1

        for i in range(1, num_segments+1):

            # Initialize a list to store the stretches of the current segment
            stretches = []
            # Initialize variables to store the current start index and length
            start_index = None
            length = 0

            # Iterate over the string
            for j, char in enumerate(combined_string):
                # If the current character is the same as the segment
                if str(i+1) == char:
                    # If this is the start of a new stretch, update the start index
                    if start_index is None:
                        start_index = j
                    # Increase the length
                    length += 1
                # If the current character is different and there was a stretch of the segment
                elif start_index is not None:
                    # Add the stretch to the list
                    stretches.append((start_index+1, start_index+length, length))
                    # Reset the start index and length
                    start_index = None
                    length = 0

            # If the last character was part of a stretch, add the stretch to the list
            if start_index is not None:
                stretches.append((start_index+1, start_index+length, length))

            # If there were any stretches, add the longest one to the coordinates
            if stretches:
                longest_stretch = max(stretches, key=lambda x: x[2])
                coordinates[chain_id].append(longest_stretch[:2])

    return coordinates

In [27]:
random.seed(23111995)

pdb_struct = read_pdb(FILE_PATH)

mb_binary, margin_binary = return_binaries(pdb_struct)

tm_indices = extract_tm_segments_indices(mb_binary)

print(tm_indices)

segments_to_string(tm_indices, pdb_struct)
elongate_tm_segments(tm_indices, pdb_struct, min_length = 20)

second = segments_to_string(tm_indices, pdb_struct)

combined = combine_strings_v2(second, margin_binary)
final_coords = extract_coordinates_v3(combined)





Chain A has length 236
{'A': [(15, 39, 24), (48, 69, 21), (85, 105, 20), (111, 134, 23), (139, 160, 21), (179, 199, 20), (208, 231, 23)]}
XXXXXXXXXXXXXX2222222222222222222222222XXXXXXX333333333333333333333333333XXXXXXXXXXX4444444444444444444444555555555555555555555555555666666666666666666666666666677XXXXXXXXXXXX777777777777777777777777777777788888888888888888888888888XX8X


In [34]:
def extract_segments(segment_dict : dict, pdb_struct):

    extracted_segments = {}

    for chain_id, segments in segment_dict.items():
        
        if chain_id in pdb_struct["CA"]:
            
            chain = pdb_struct["CA"][chain_id] # Is a SuperOD
            
            for i, (start, end) in enumerate(segments):
                
                for data in chain[start:end].items():
                
                    print(data)

    return 0

extract_segments(final_coords, pdb_struct)

('16', {'coord': [10.07, -4.915, 15.297], 'res_name': 'T', 'res_number': '16'})
('17', {'coord': [11.46, -8.092, 13.754], 'res_name': 'L', 'res_number': '17'})
('18', {'coord': [8.989, -8.022, 10.876], 'res_name': 'W', 'res_number': '18'})
('19', {'coord': [9.307, -4.231, 10.732], 'res_name': 'L', 'res_number': '19'})
('20', {'coord': [12.977, -5.003, 10.414], 'res_name': 'G', 'res_number': '20'})
('21', {'coord': [12.9, -7.515, 7.54], 'res_name': 'I', 'res_number': '21'})
('22', {'coord': [10.425, -5.144, 5.892], 'res_name': 'G', 'res_number': '22'})
('23', {'coord': [12.909, -2.287, 5.746], 'res_name': 'T', 'res_number': '23'})
('24', {'coord': [15.61, -4.719, 4.791], 'res_name': 'L', 'res_number': '24'})
('25', {'coord': [13.375, -6.001, 2.002], 'res_name': 'L', 'res_number': '25'})
('26', {'coord': [11.843, -2.696, 0.993], 'res_name': 'M', 'res_number': '26'})
('27', {'coord': [15.281, -1.157, 0.507], 'res_name': 'L', 'res_number': '27'})
('28', {'coord': [16.625, -4.259, -1.292], 

0

In [32]:
pdb_struct["CA"]["A"][1:2]

SuperOD([('2',
          {'coord': [1.961, -16.242, 22.245],
           'res_name': 'A',
           'res_number': '2'})])

In [23]:
from collections import OrderedDict

class SlicableOrderedDict(OrderedDict):
    def __getitem__(self, key):
        if isinstance(key, slice):
            # Get the slice indices
            keys = list(self.keys())[key]
            return SlicableOrderedDict((k, self[k]) for k in keys)
        else:
            # If not a slice, use the standard __getitem__ from OrderedDict
            return OrderedDict.__getitem__(self, key)

# Example usage
my_od = SlicableOrderedDict([('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)])



# Access by slice
print(my_od[1:4])  # Output: SlicableOrderedDict([('b', 2), ('c', 3), ('d', 4)])


SlicableOrderedDict([('b', 2), ('c', 3), ('d', 4)])
