In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
from alphafold.common import residue_constants

In [2]:
# Load the reference feature dictionary to compare the generated features against.
# NOTE: pickle.load can execute arbitrary code -- only open files you trust.
with open("features.pkl", "rb") as pkl_file:
    feature_dict = pickle.load(pkl_file)

# Which features does the reference contain?
feature_dict.keys()

dict_keys(['aatype', 'between_segment_residues', 'domain_name', 'residue_index', 'seq_length', 'sequence', 'deletion_matrix_int', 'msa', 'num_alignments', 'msa_species_identifiers', 'template_aatype', 'template_all_atom_masks', 'template_all_atom_positions', 'template_domain_names', 'template_sequence', 'template_sum_probs'])

In [3]:
# Residue numbering stored in the reference features.
residue_index_ref = feature_dict["residue_index"]
residue_index_ref

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54, 55])

In [46]:
def parse_fasta(fasta_string: str):
    """Parses a FASTA string into sequences and their header descriptions.

    Arguments:
      fasta_string: The string contents of a FASTA file.

    Returns:
      A tuple of two lists:
      * A list of sequences; each is the concatenation of that record's
        sequence lines with surrounding whitespace stripped.
      * A list of sequence descriptions taken from the '>' header lines, in
        the same order as the sequences.

    Raises:
      ValueError: If a non-blank sequence line appears before the first '>'
        header line.
    """
    sequences = []
    descriptions = []
    for line in fasta_string.splitlines():
        line = line.strip()
        if not line:
            continue  # Skip blank lines.
        if line.startswith('>'):
            descriptions.append(line[1:])  # Remove the '>' at the beginning.
            sequences.append('')
        elif not sequences:
            # Without this check, sequences[-1] on an empty list would raise
            # an IndexError that gives no hint about the malformed input.
            raise ValueError(
                'FASTA sequence data found before the first ">" header line.')
        else:
            sequences[-1] += line

    return sequences, descriptions


def make_sequence_features(
    sequence: str, description: str, num_res: int):
    """Builds the sequence part of the input feature dict.

    Args:
      sequence: Query amino-acid sequence (one-letter codes); one-hot encoded
        via residue_constants with unknown residues mapped to X.
      description: Header text, stored UTF-8 encoded as ``domain_name``.
      num_res: Number of residues; sizes the per-residue arrays.

    Returns:
      A dict with keys ``aatype``, ``between_segment_residues``,
      ``domain_name``, ``residue_index``, ``seq_length`` and ``sequence``.
    """
    return {
        'aatype': residue_constants.sequence_to_onehot(
            sequence=sequence,
            mapping=residue_constants.restype_order_with_x,
            map_unknown_to_x=True),
        'between_segment_residues': np.zeros((num_res,), dtype=np.int32),
        'domain_name': np.array([description.encode('utf-8')],
                                dtype=np.object_),
        # Residues are numbered 0 .. num_res - 1.
        'residue_index': np.arange(num_res, dtype=np.int32),
        # The total length is repeated once per residue.
        'seq_length': np.full((num_res,), num_res, dtype=np.int32),
        'sequence': np.array([sequence.encode('utf-8')], dtype=np.object_),
    }



def make_empty_msa_features(
    sequence: str, num_res: int):
    """Constructs MSA features for an MSA whose only row is the query.

    Args:
      sequence: Query amino-acid sequence; every character must be a key of
        ``residue_constants.HHBLITS_AA_TO_ID`` (otherwise KeyError).
      num_res: Number of residues; sizes the per-residue arrays.

    Returns:
      A dict with:
      * ``deletion_matrix_int``: int32 zeros of shape (1, num_res).
      * ``msa``: int32 array of shape (1, len(sequence)).
      * ``num_alignments``: int32 array of shape (num_res,), every entry 1.
      * ``msa_species_identifiers``: object array of shape (1,), holding one
        (empty) identifier per MSA row.
    """
    # The query sequence is the first and only MSA row.
    int_msa = [[residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]]
    # One species identifier per MSA row; the query has none, hence b''.
    species_ids = [b'']
    num_alignments = len(int_msa)

    features = {}
    features['deletion_matrix_int'] = np.zeros(
        (num_alignments, num_res), dtype=np.int32)
    features['msa'] = np.array(int_msa, dtype=np.int32)
    features['num_alignments'] = np.array(
        [num_alignments] * num_res, dtype=np.int32)
    # Built from a list, not a bare b'', so the array has shape
    # (num_alignments,) rather than being 0-dimensional.
    features['msa_species_identifiers'] = np.array(species_ids, dtype=np.object_)
    return features


# Dtype of each template feature. The key order here is also the key order of
# the dict returned by make_empty_template_features.
TEMPLATE_FEATURES = dict(
    template_aatype=np.float32,
    template_all_atom_masks=np.float32,
    template_all_atom_positions=np.float32,
    template_domain_names=object,
    template_sequence=object,
    template_sum_probs=np.float32,
)


def make_empty_template_features(
    sequence: str, num_res: int):
    """Returns template features for the case where no template is used.

    Every key of TEMPLATE_FEATURES maps to an empty 1-D array (shape (0,))
    carrying that key's dtype, so the dtypes are correct even though the
    arrays hold no data.

    Args:
      sequence: Query sequence. Not used; presumably accepted for signature
        parity with the other feature builders.
      num_res: Number of residues. Not used either.

    Returns:
      Dict from template feature name to an empty NumPy array.
    """
    return {
        feature_name: np.array([], dtype=feature_dtype)
        for feature_name, feature_dtype in TEMPLATE_FEATURES.items()
    }

In [47]:
# Build input features for a single-sequence FASTA, using an MSA that holds
# only the query and empty template features.
fasta_path = "input/mono_set1/GA98.fasta"

# Only the read needs the open file; parse after the handle is closed.
with open(fasta_path) as f:
    input_fasta_str = f.read()

input_seqs, input_descs = parse_fasta(input_fasta_str)
if len(input_seqs) != 1:
    # `!= 1` also catches an empty FASTA, so report the actual count.
    raise ValueError(
        f'Expected exactly one input sequence in {fasta_path}, '
        f'found {len(input_seqs)}.')
input_sequence = input_seqs[0]
# Use the first whitespace-separated token of the header as the description;
# a bare '>' header has no tokens, so fall back to an empty string.
desc_tokens = input_descs[0].split()
input_description = desc_tokens[0] if desc_tokens else ''
num_res = len(input_sequence)

# Generate feature: sequence part
feature_test = make_sequence_features(
    sequence=input_sequence,
    description=input_description,
    num_res=num_res)

# Generate feature: MSA part
feature_test.update(make_empty_msa_features(
    sequence=input_sequence,
    num_res=num_res))

# Generate feature: template part
feature_test.update(make_empty_template_features(
    sequence=input_sequence,
    num_res=num_res))

In [None]:
# np.array_equiv broadcasts its inputs, so a 0-d b'' would compare "equal" to
# a shape-(1,) [b''] array; np.array_equal also requires identical shapes.
np.array_equal(feature_dict["msa_species_identifiers"],
               feature_test["msa_species_identifiers"])

True

In [None]:
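# NOTE(review): before enabling, `import os` and define `output_dir` (neither is
# defined in the cells shown), and decide which dict to save -- as written this
# re-dumps the loaded reference `feature_dict`, not the generated `feature_test`.
# features_output_path = os.path.join(output_dir, 'features.pkl')
# with open(features_output_path, 'wb') as f:
#     pickle.dump(feature_dict, f, protocol=4)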