In [1]:
from extract_contacts_and_attentions import *

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [2]:
# Ignore PDBConstructionWarning so that parsing many PDB files does not flood the cell output.
# NOTE(review): warnings, PDBConstructionWarning and PDBParser are not imported in this notebook;
# presumably they arrive through the star import in the first cell -- confirm.
warnings.simplefilter('ignore', PDBConstructionWarning)
# Shared module-level parser; check_sequences and output_x_y read this instance
# (calc_contact_sites builds its own local PDBParser instead).
parser = PDBParser()  

In [3]:
def check_sequences(pdb_id):
    """Check that chain A of a PDB file agrees with the stored sequence for ``pdb_id``.

    The structure is read from ``<structure_dir>/<pdb_id>.pdb`` with the module-level
    ``parser``. Residues of chain A (first model) are compared position by position
    against ``protein_data[pdb_id]``: the n-th residue of the chain is compared with
    the n-th character of the sequence. Residues without a CA atom still occupy a
    position but are not compared, and residues beyond the end of the stored sequence
    are ignored.

    Note that only residues present in the chain are compared, so a chain that is
    shorter than the stored sequence still counts as a match.

    Args:
        pdb_id: Identifier used both as the PDB file stem and as the key into the
            module-level ``protein_data`` mapping.

    Returns:
        ``pdb_id`` when chain A exists and no compared residue disagrees with the
        stored sequence; ``None`` when chain A is missing or any residue mismatches.
    """
    import os  # local import: keeps this cell self-contained without touching the import cell

    # os.path.join works whether or not structure_dir ends with a path separator,
    # which matches the "<dir>/<id>.pdb" form used by the other functions.
    pdb_path = os.path.join(structure_dir, f"{pdb_id}.pdb")
    structure = parser.get_structure(pdb_id, pdb_path)

    # Use the first model of the structure (typically the only one).
    protein_structure = structure[0]
    if 'A' not in protein_structure:
        return None

    expected_sequence = protein_data[pdb_id]
    for residue_position, residue in enumerate(protein_structure['A']):
        if residue_position >= len(expected_sequence):
            # Everything past the stored sequence is ignored, so stop scanning.
            break
        if 'CA' not in residue:
            # Residues without a CA atom keep their position but are not compared.
            continue
        if simple_aa(residue.resname) != expected_sequence[residue_position]:
            # One mismatch is enough to reject the structure.
            return None
    return pdb_id


In [4]:
def check_casp_pdb_seqs(protein_data):
    """Return the IDs whose PDB chain A agrees with the stored sequence.

    Each key of ``protein_data`` is passed to ``check_sequences``; IDs for which it
    returns ``None`` (missing chain A or a residue mismatch) are dropped.

    Note: ``check_sequences`` looks sequences up in the module-level ``protein_data``,
    so this argument only supplies the IDs to check.

    Args:
        protein_data: Mapping whose keys are the PDB IDs to check.

    Returns:
        List of the accepted IDs, in the mapping's iteration order.
    """
    # Snapshot the keys first so the result does not depend on the mapping staying
    # unchanged while the structures are parsed.
    candidate_ids = list(protein_data)
    return [pdb_id for pdb_id in candidate_ids if check_sequences(pdb_id) is not None]

In [5]:
def calc_contact_sites(pdb_id, in_contact_sites, non_contact_sites, subset_non_contact_sites,
                       contact_cutoff=5, min_separation=2):
    """Classify residue pairs of chain A as contacts or non-contacts by CA-CA distance.

    The structure is read from ``<structure_dir>/<pdb_id>.pdb`` and every unordered
    pair of residues in chain A of the first model is visited once. A pair is skipped
    when either residue number exceeds ``len(protein_data[pdb_id])``, when the residue
    numbers differ by ``min_separation`` or less, or when either residue lacks a CA atom.
    Remaining pairs are appended to ``in_contact_sites[pdb_id]`` if their CA-CA distance
    is below ``contact_cutoff``, otherwise to ``non_contact_sites[pdb_id]``.

    Afterwards, if any non-contacts were recorded, ``subset_non_contact_sites[pdb_id]``
    is set to a random sample of them, at most as many as there are contacts.
    The draw uses the global ``random`` state.

    Args:
        pdb_id: PDB file stem and key into the module-level ``protein_data``.
        in_contact_sites: Mapping ``pdb_id -> list`` that receives contact records.
            It is indexed without a prior membership check, so it must create missing
            keys on access (e.g. ``defaultdict(list)``) or already contain ``pdb_id``.
        non_contact_sites: Same as above, for non-contact records.
        subset_non_contact_sites: Mapping that receives the sampled non-contacts.
        contact_cutoff: Distance threshold below which a pair counts as a contact,
            in the coordinate units of the PDB file (presumably angstroms). Default 5.
        min_separation: Pairs whose residue numbers differ by this much or less are
            ignored. Default 2.

    Returns:
        A summary string ``"Total contacts found <pdb_id>: <count>"``.
    """
    parser = PDBParser()
    structure = parser.get_structure(pdb_id, f"{structure_dir}/{pdb_id}.pdb")
    residues = list(structure[0]['A'])
    sequence_length = len(protein_data[pdb_id])

    count = 0

    for i, residue1 in enumerate(residues):
        number1 = residue1.id[1]
        if number1 > sequence_length:
            # Residue lies beyond the stored sequence; none of its pairs are used.
            continue
        # Pair each residue only with the ones before it: every unordered pair is
        # visited once, with residue1 being the later residue in the chain.
        for residue2 in residues[:i]:
            number2 = residue2.id[1]
            if number2 > sequence_length:
                continue
            if abs(number1 - number2) <= min_separation:
                # Near neighbours along the sequence are never recorded.
                continue
            try:
                # Subtracting two Biopython atoms yields the distance between them.
                distance = abs(residue1['CA'] - residue2['CA'])
            except KeyError:
                # At least one residue has no CA atom.
                continue

            in_contact = bool(distance < contact_cutoff)
            target_sites = in_contact_sites if in_contact else non_contact_sites
            target_sites[pdb_id].append({
                'res_1': number1,
                'res_2': number2,
                'sig_1': simple_aa(residue1.resname),
                'sig_2': simple_aa(residue2.resname),
                'dist': distance,
                'in_contact': in_contact
            })
            if in_contact:
                count += 1

    if non_contact_sites[pdb_id]:
        # Balance the classes: keep at most as many non-contacts as contacts.
        sample_size = min(len(non_contact_sites[pdb_id]), len(in_contact_sites[pdb_id]))
        subset_non_contact_sites[pdb_id] = random.sample(non_contact_sites[pdb_id], sample_size)

    return f"Total contacts found {pdb_id}: {count}"

In [6]:
def contacts_per_pdb(same_sequence_ids):
    """Run ``calc_contact_sites`` for every ID and collect the per-protein results.

    Args:
        same_sequence_ids: Iterable of PDB IDs to process.

    Returns:
        A tuple ``(in_contact_sites, non_contact_sites, subset_non_contact_sites)``
        of ``defaultdict(list)`` objects keyed by PDB ID, filled in place by
        ``calc_contact_sites``.
    """
    contacts, non_contacts, sampled_non_contacts = (defaultdict(list) for _ in range(3))

    for pdb_id in same_sequence_ids:
        # The returned summary string is not needed here; only the filled mappings are.
        calc_contact_sites(pdb_id, contacts, non_contacts, sampled_non_contacts)

    return contacts, non_contacts, sampled_non_contacts

In [7]:
# Merge the in-contact pairs and the sampled non-contact pairs into one list per protein

def generate_contact_data(in_contact_sites, subset_non_contact_sites):
    """Combine contact and sampled non-contact records into one mapping per protein.

    For every protein, the records from ``in_contact_sites`` come first, followed by
    those from ``subset_non_contact_sites``. Each record is copied into a new dict
    holding exactly the six fields listed below; any other keys are dropped.
    Proteins whose record lists are empty do not get an entry.

    Args:
        in_contact_sites: Mapping ``pdb_id -> list of record dicts`` for contacts.
        subset_non_contact_sites: Mapping of the same form for non-contacts.

    Returns:
        ``defaultdict(list)`` mapping each PDB ID to its combined list of records.
    """
    record_fields = ('res_1', 'res_2', 'sig_1', 'sig_2', 'dist', 'in_contact')
    merged = defaultdict(list)

    # Contacts are processed before non-contacts so they come first in each list.
    for source in (in_contact_sites, subset_non_contact_sites):
        for pdb_id, records in source.items():
            for record in records:
                merged[pdb_id].append({field: record[field] for field in record_fields})

    return merged


In [8]:
# Index protein sequence as sequence 0 (next sequence would be indexed as 1)

def generate_embeddings(pdb_id):
    """Run the module-level ESM-2 ``model`` on one protein and return its attention maps.

    Despite the name, the value returned is ``results['attentions']``, not the
    residue embeddings.

    Args:
        pdb_id: Key into the module-level ``protein_data`` mapping of sequences.

    Returns:
        The ``'attentions'`` entry of the model output. ``get_x_y`` indexes it as
        ``[batch][layer][head][token][token]``; the result stays on the device the
        model ran on.
    """
    protein_sequence = protein_data[pdb_id]
    # The batch holds a single sequence, labelled 0.
    esm_input_data = [(0, protein_sequence)]

    # Tokenise the sequence for the model.
    batch_converter = alphabet.get_batch_converter()
    _labels, _strs, batch_tokens = batch_converter(esm_input_data)

    # Put the tokens on whatever device the model's weights are on. Choosing CUDA
    # just because it is available would fail if the model was left on the CPU.
    model_device = next(model.parameters()).device
    batch_tokens = batch_tokens.to(model_device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=True)

    return results['attentions']

In [9]:
# Extract attentions from all heads and layers for given amino acid residues

def get_x_y(attention_data, res_1, res_2):
    """Collect the attention value between two positions from every layer and head.

    Args:
        attention_data: Attention maps indexed as
            ``[batch][layer][head][position][position]``; only batch entry 0 is read.
            May be an array/tensor (anything offering ``reshape`` and ``tolist``)
            or plain nested sequences.
        res_1: First position index (row of each attention map).
        res_2: Second position index (column of each attention map).

    Returns:
        A flat list with one value per (layer, head), ordered layer-major. The number
        of layers and heads is taken from ``attention_data`` itself. For array/tensor
        input the values are plain Python numbers, so the returned features do not
        keep the full attention tensor (or GPU memory) alive.
    """
    per_sequence = attention_data[0]

    if hasattr(per_sequence, "reshape") and hasattr(per_sequence, "tolist"):
        # Array/tensor input: take the whole (layer, head) slice at once and convert
        # it to Python numbers in one step instead of indexing element by element.
        return per_sequence[:, :, res_1, res_2].reshape(-1).tolist()

    # Nested-sequence input: walk layers, then heads, in order.
    return [head[res_1][res_2] for layer in per_sequence for head in layer]


In [10]:
def output_x_y(sequence_ids, contact_data):
    """Build the feature rows and labels for all residue pairs of the given proteins.

    For each ID the PDB file ``<structure_dir>/<pdb_id>.pdb`` is parsed with the
    module-level ``parser``. Proteins whose first residue in chain A is not numbered 1
    are skipped. For the rest, attention maps are computed once via
    ``generate_embeddings`` and one feature row per record in ``contact_data[pdb_id]``
    is produced with ``get_x_y``.

    Args:
        sequence_ids: Iterable of PDB IDs to process.
        contact_data: Mapping ``pdb_id -> list of records``; each record must provide
            ``'res_1'``, ``'res_2'`` and ``'in_contact'``.

    Returns:
        A tuple ``(X, y)``: ``X`` is the list of feature rows returned by ``get_x_y``
        and ``y`` the matching list of ``'in_contact'`` labels.
    """
    X = []
    y = []

    for iteration, pdb_id in enumerate(sequence_ids):
        structure = parser.get_structure(pdb_id, f"{structure_dir}/{pdb_id}.pdb")
        chain = structure[0]['A']
        first_residue = list(chain.get_residues())[0].id[1]
        print('Iteration: ', iteration)

        if first_residue != 1:
            # The residue numbers are used directly as indices into the attention maps,
            # so only chains that start at residue number 1 are used.
            continue

        attention_data = generate_embeddings(pdb_id)
        for pair in contact_data[pdb_id]:
            X.append(get_x_y(attention_data, pair['res_1'], pair['res_2']))
            y.append(pair['in_contact'])

    return X, y



In [11]:
# Get sequences from CASP7 file
prot_data_dict = parse_casp7_file(casp_95)

# Count how often each base ID (the text before the first '_') occurs, then keep the IDs seen exactly once.
# NOTE(review): single_occurence_ids is only consumed by the commented-out generate_fastas call below.
protein_id_counts = Counter(protein_id.split('_')[0] for protein_id in prot_data_dict)
single_occurence_ids = [protein_id for protein_id, count in protein_id_counts.items() if count == 1]

# Sequences are loaded from fasta_dir rather than regenerated.
# protein_data = generate_fastas(single_occurence_ids)
protein_data = load_fastas(fasta_dir)
# Keep only the IDs whose PDB chain A agrees with the loaded sequence.
same_sequence_ids = check_casp_pdb_seqs(protein_data)
# Classify residue pairs per protein; the third mapping holds the randomly sampled non-contacts.
in_contact_sites, non_contact_sites, subset_non_contact_sites = contacts_per_pdb(same_sequence_ids)
# Merge contacts and sampled non-contacts into one list of records per protein.
contact_data = generate_contact_data(in_contact_sites, subset_non_contact_sites)


In [13]:
# NOTE(review): this seed is set after contacts_per_pdb has already run in an earlier cell.
# calc_contact_sites draws the non-contact subset with random.sample, so that draw is not
# covered by this seed when the cells run top to bottom -- consider seeding before that cell.
seed_value = 67
random.seed(seed_value)
# NOTE(review): n_sequences is only used by the commented-out sampling line below.
n_sequences = 300

# sequence_ids = random.sample(same_sequence_ids, n_sequences)
# All sequence-checked IDs are used; the random subsample above is disabled.
sequence_ids = same_sequence_ids
print(sequence_ids[:5])

['1UE8', '1F8E', '1J5B', '1P0Z', '1N3K']


## Linear SVC

In [14]:
from sklearn.svm import LinearSVC


In [None]:
# Build one feature row per residue pair (X) and the matching in_contact labels (y).
# output_x_y prints one "Iteration" line per protein, which produces the long log below.
X, y = output_x_y(sequence_ids, contact_data)

Iteration:  0
Iteration:  1
Iteration:  2
Iteration:  3
Iteration:  4
Iteration:  5
Iteration:  6
Iteration:  7
Iteration:  8
Iteration:  9
Iteration:  10
Iteration:  11
Iteration:  12
Iteration:  13
Iteration:  14
Iteration:  15
Iteration:  16
Iteration:  17
Iteration:  18
Iteration:  19
Iteration:  20
Iteration:  21
Iteration:  22
Iteration:  23
Iteration:  24
Iteration:  25
Iteration:  26
Iteration:  27
Iteration:  28
Iteration:  29
Iteration:  30
Iteration:  31
Iteration:  32
Iteration:  33
Iteration:  34
Iteration:  35
Iteration:  36
Iteration:  37
Iteration:  38
Iteration:  39
Iteration:  40
Iteration:  41
Iteration:  42
Iteration:  43
Iteration:  44
Iteration:  45
Iteration:  46
Iteration:  47
Iteration:  48
Iteration:  49
Iteration:  50
Iteration:  51
Iteration:  52
Iteration:  53
Iteration:  54
Iteration:  55
Iteration:  56
Iteration:  57
Iteration:  58
Iteration:  59
Iteration:  60
Iteration:  61
Iteration:  62
Iteration:  63
Iteration:  64
Iteration:  65
Iteration:  66
Itera

Iteration:  519
Iteration:  520
Iteration:  521
Iteration:  522
Iteration:  523
Iteration:  524
Iteration:  525
Iteration:  526
Iteration:  527
Iteration:  528
Iteration:  529
Iteration:  530
Iteration:  531
Iteration:  532
Iteration:  533
Iteration:  534
Iteration:  535
Iteration:  536
Iteration:  537
Iteration:  538
Iteration:  539
Iteration:  540
Iteration:  541
Iteration:  542
Iteration:  543
Iteration:  544


In [None]:
# Hold out 40% of the rows for testing, with a fixed random_state.
# NOTE(review): train_test_split is not imported in this notebook (presumably it comes from
# the star import -- confirm). X is a flat list of residue pairs from all proteins, so pairs
# from the same protein can end up in both the training and the test set.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size= 0.4, random_state=55)

In [None]:
# Fit a linear support-vector classifier on the training rows, with a fixed random_state.
# NOTE(review): all other LinearSVC settings are left at their defaults -- check the fit
# for convergence warnings.
linear_svc = LinearSVC(random_state=55)
linear_svc.fit(X_train, y_train)

In [None]:
# Predict the held-out rows and report plain accuracy, printed with two decimals.
# NOTE(review): accuracy_score is not imported in this notebook; presumably it comes from
# the star import -- confirm.
y_pred = linear_svc.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")