# ESM-2 distogram predictor

## Distogram

In [None]:
!pip install biopython



In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from Bio.PDB import PDBParser, NeighborSearch, PDBList
from Bio.PDB.Polypeptide import is_aa

In [None]:
import os
from google.colab import drive

# Mount Google Drive, then work from the project folder so the relative paths below resolve.
drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/flamingo-pep-gen')

Mounted at /content/drive


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from Bio.PDB import PDBParser, is_aa
import os
import glob

# Directory (relative to the current working directory) holding the PINDER PDB files.
parent_dir = 'pinder-dataset/2023-11/pdbs/2023-11/pdbs/'
# glob returns files in arbitrary filesystem order; sort so the processing order is reproducible.
pdb_files = sorted(glob.glob(os.path.join(parent_dir, '*.pdb')))


In [None]:
# Number of .pdb files found under parent_dir.
len(pdb_files)

In [None]:
# pdb_files = [f for f in os.listdir(parent_dir) if os.path.isfile(os.path.join(parent_dir, f))]

In [None]:
# Re-check the file count (the commented-out cell above did not change pdb_files).
len(pdb_files)

In [None]:
# Output folder for the distogram matrices and their PNG renderings.
save_path = ('/content/drive/MyDrive/flamingo-pep-gen/'
             'pinder-dataset/2023-11/distograms-10A/')

In [None]:
# os.mkdir('/content/drive/MyDrive/flamingo-pep-gen/pinder-dataset/2023-11/distograms-10A/')

In [None]:
save_path = '/content/drive/MyDrive/flamingo-pep-gen/pinder-dataset/2023-11/distograms-10A/'
# Counter for processed files (incremented for both freshly processed and skipped files)
processed_files_count = 0


def _output_paths(pdb_id, save_path):
    """Return the (matrix .npy, distogram PNG, contact-map PNG) paths for ``pdb_id``."""
    return (
        os.path.join(save_path, pdb_id + "_matrix.npy"),
        os.path.join(save_path, pdb_id + "_distogram.png"),
        os.path.join(save_path, pdb_id + "_contact_map.png"),
    )


def _ca_coordinates(structure):
    """Return an (N, 3) float64 array of C-alpha coordinates for the first model.

    Only standard amino-acid residues that have a "CA" atom are kept, in chain order
    and then file order. Residues are keyed by the full Biopython residue id
    (hetflag, resseq, insertion code), so insertion-code residues such as 52 and 52A
    stay distinct instead of overwriting each other.
    """
    first_model = next(iter(structure), None)
    if first_model is None:
        return np.empty((0, 3))
    coords = {}
    for chain in first_model:
        chain_id = chain.get_id()
        for residue in chain:
            if is_aa(residue, standard=True) and "CA" in residue:
                coords[(chain_id, residue.get_id())] = residue["CA"].get_coord()
    return np.array(list(coords.values()), dtype=np.float64).reshape(-1, 3)


def _pairwise_distances(coords):
    """Euclidean distance matrix (N x N) between the rows of ``coords``; diagonal is 0."""
    diff = coords[:, None, :] - coords[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def _save_heatmap(matrix, title, image_path, colorbar_label=None):
    """Render ``matrix`` (origin lower-left, viridis) and write it to ``image_path``."""
    n = matrix.shape[0]
    fig, ax = plt.subplots()
    try:
        image = ax.imshow(matrix, cmap="viridis", origin="lower", extent=[0, n, 0, n])
        if colorbar_label is not None:
            fig.colorbar(image, ax=ax, label=colorbar_label)
        ax.set_title(title)
        fig.savefig(image_path)
    finally:
        # Always release the figure; this runs once per file in a long loop.
        plt.close(fig)


def process_pdb_file(pdb_file_path, save_path, interaction_distance_threshold=10):
    """Compute and save the C-alpha distogram and binary contact map of one PDB file.

    Writes three files into ``save_path``:
      * ``<pdb_id>_matrix.npy``: N x N C-alpha distance matrix (Å)
      * ``<pdb_id>_distogram.png``: heat map of that matrix
      * ``<pdb_id>_contact_map.png``: ``distance < interaction_distance_threshold``

    ``pdb_id`` is the file name up to its first '.'. If all three outputs already
    exist, the file is skipped. Structures without any standard-residue C-alpha
    atoms are reported and skipped without writing anything. Increments the global
    ``processed_files_count`` for processed and already-processed files.

    Args:
        pdb_file_path: path to the .pdb file.
        save_path: output directory (must exist).
        interaction_distance_threshold: contact cut-off in Å (default 10).
    """
    global processed_files_count

    pdb_id = os.path.basename(pdb_file_path).split('.')[0]
    matrix_file, distogram_image_file, contact_map_image_file = _output_paths(pdb_id, save_path)

    if all(os.path.exists(p) for p in (matrix_file, distogram_image_file, contact_map_image_file)):
        print(f"Skipping already processed file: {pdb_file_path}")
        print('processing_files_count:', processed_files_count)
        processed_files_count += 1
        return

    print(f"Processing PDB file at {pdb_file_path}")
    structure = PDBParser(QUIET=True).get_structure(pdb_id, pdb_file_path)
    coords = _ca_coordinates(structure)
    if len(coords) == 0:
        # imshow cannot autoscale an empty matrix; skip instead of crashing the whole loop.
        print(f"No standard amino-acid CA atoms in {pdb_file_path}; skipping")
        return

    dist_matrix = _pairwise_distances(coords)
    np.save(matrix_file, dist_matrix)

    _save_heatmap(dist_matrix, f"Residue-Residue Distogram for {pdb_id}",
                  distogram_image_file, colorbar_label="Distance (Å)")
    _save_heatmap(dist_matrix < interaction_distance_threshold,
                  f"Residue-Residue Binary Contact Map for {pdb_id}",
                  contact_map_image_file)

    print('processing_files_count:', processed_files_count)
    processed_files_count += 1


In [None]:

# Compute (or skip, when already on disk) the distogram outputs of every PDB file.
for path in pdb_files:
    process_pdb_file(path, save_path)


## Concatenated ESM-2 embeddings

In [None]:
import os
import pandas as pd

# Inputs: PINDER metadata table and the folder of precomputed 10 Å distograms.
metadata_path = '/content/metadata.csv'
distogram_dir = '/content/drive/MyDrive/distograms-10A'

metadata = pd.read_csv(metadata_path)

# Distinct complex ids (column 'id') listed in the metadata.
metadata_ids = set(metadata['id'].to_list())


In [None]:
# Number of distinct complex ids in the metadata table.
len(metadata_ids)

239098

In [None]:
# Distogram ids are the '<id>_distogram.png' file names with the suffix removed,
# restricted to pair ids (which contain '--'). Matching on the suffix keeps the
# sibling '<id>_matrix.npy' / '<id>_contact_map.png' files out of the id set.
DISTOGRAM_SUFFIX = '_distogram.png'
distogram_files = os.listdir(distogram_dir)
distogram_ids = {
    name[:-len(DISTOGRAM_SUFFIX)]
    for name in distogram_files
    if name.endswith(DISTOGRAM_SUFFIX) and '--' in name
}

# Ids present both in the metadata and among the computed distograms.
intersecting_ids = metadata_ids.intersection(distogram_ids)

print(f"Number of matching IDs: {len(intersecting_ids)}")
# Show a small deterministic preview instead of dumping thousands of ids.
print("Matching IDs (first 10, sorted):", sorted(intersecting_ids)[:10])


Number of matching IDs: 4697
Matching IDs: {'6kpa__F1_Q97YJ9--6kpa__G1_Q97YJ9', '4fmg__C1_C0JPK1--4fmg__D1_C0JPK1', '7yy4__B2_B1MDL6--7yy4__C2_B1MDL6', '4lg2__A1_Q8JPY0--4lg2__B1_Q8JPY0', '4jk1__C1_P0A8V2--4jk1__F1_P00579', '6ck7__A1_Q5ZSC4--6ck7__B1_Q5ZSC4', '7p5x__C1_P60281--7p5x__E1_A0QWT1', '6j5a__B1_A0A287B4I0--6j5a__D1_Q95339', '3h87__C2_O07227--3h87__D2_O07227', '4p18__G1_P07798--4p18__W1_P07798', '3u8q__A1_P24627--3u8q__B1_P24627', '4p72__A1_Q9I0A4--4p72__C1_Q9I0A3', '5bse__B1_G7KRM5--5bse__D1_G7KRM5', '6phe__D1_A0A077EEZ5--6phe__F1_A0A077EEZ5', '6rb9__A1_Q02307--6rb9__B1_Q02307', '7yu8__B1_P63096--7yu8__C1_P54311', '2jdm__B1_Q9HYN5--2jdm__C1_Q9HYN5', '1z88__A1_Q540U1--1z88__B1_Q540U1', '1ezv__D2_P07143--1ezv__I2_P22289', '3hi4__A1_P22862--3hi4__B1_P22862', '2ooz__A1_P14174--2ooz__B1_P14174', '3ju9__A2_P55915--3ju9__A4_P55915', '7ylv__C1_P39076--7ylv__D1_P39076', '3rk3__A1_P63027--3rk3__B1_P32851', '2qin__B1_P52700--2qin__D1_P52700', '2ag0__A1_Q9F4L3--2ag0__B1_Q9F4L3', '6x2m__B

In [None]:
# Sort so the CSV row order is reproducible: iteration order of a set of strings
# changes between Python sessions because of string-hash randomisation.
intersecting_ids_df = pd.DataFrame(sorted(intersecting_ids), columns=['id'])
output_csv_path = '/content/intersecting_ids.csv'  # Define path for the output CSV

# Save the DataFrame to a CSV file
intersecting_ids_df.to_csv(output_csv_path, index=False)

print(f"Intersecting IDs saved to {output_csv_path}")

Intersecting IDs saved to /content/intersecting_ids.csv


In [None]:
# Summarise rather than rendering the whole set (thousands of ids).
len(distogram_ids), sorted(distogram_ids)[:10]

In [None]:
!pip install fair-esm

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m92.2/93.1 kB[0m [31m2.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0


In [None]:
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from transformers import T5ForConditionalGeneration, T5Tokenizer
import esm
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

# ESM-2 650M model and its alphabet/tokenizer, used for all embeddings below (inference only).
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_model.eval()
if torch.cuda.is_available():
    esm_model.cuda()  # nn.Module.cuda() moves the parameters in place
batch_converter = alphabet.get_batch_converter()

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt


In [None]:
import torch
from torch.utils.data import Dataset
import esm

class ProteinPairsDataset(Dataset):
    """ESM-2 embeddings for sequence pairs stored in a comma-joined FASTA file.

    Each FASTA record is expected to look like::

        >HEADER
        SEQUENCE1,SEQUENCE2

    The sequence part may be wrapped over several lines; lines are concatenated
    before splitting on the comma, as Bio.SeqIO does. Each item is
    ``{header: tensor}``. The tensor holds the per-token layer-33 representations of
    SEQUENCE1 followed by those of SEQUENCE2, concatenated along dim 0.

    Relies on the module-level ``esm_model``, ``alphabet`` and ``batch_converter``.
    """

    def __init__(self, fasta_file):
        self.sequence_data = {}
        self._headers = []
        self.load_sequences(fasta_file)

    def load_sequences(self, fasta_file):
        """Parse ``fasta_file`` into ``self.sequence_data[header] = (seq1, seq2)``.

        Blank lines are ignored and multi-line sequence blocks are concatenated.
        Headers without any sequence line are skipped. A sequence appearing before
        the first header is filed under the empty header "".

        Raises:
            ValueError: if a record does not hold exactly two comma-separated sequences.
        """
        records = []  # [(header, [sequence lines])] in file order
        with open(fasta_file, 'r') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                if line.startswith('>'):
                    records.append((line[1:].strip(), []))
                else:
                    if not records:
                        records.append(("", []))
                    records[-1][1].append(line)

        for header, lines in records:
            if not lines:
                continue
            parts = "".join(lines).split(',')
            if len(parts) != 2:
                raise ValueError(
                    f"Record {header!r} in {fasta_file} does not contain exactly "
                    f"two comma-separated sequences"
                )
            self.sequence_data[header] = (parts[0], parts[1])
        # Cache the header order so __getitem__ is O(1) instead of rebuilding the key list per call.
        self._headers = list(self.sequence_data)

    def __len__(self):
        return len(self.sequence_data)

    def __getitem__(self, idx):
        """Return ``{header: embedding(seq1) ++ embedding(seq2)}`` for the idx-th record."""
        header = self._headers[idx]
        seq1, seq2 = self.sequence_data[header]
        emb1 = self.generate_esm_embedding(seq1)
        emb2 = self.generate_esm_embedding(seq2)
        return {header: torch.cat((emb1, emb2), dim=0)}

    def generate_esm_embedding(self, seq):
        """Layer-33 representation of ``seq`` on CPU, without the first/last special tokens."""
        batch_labels, batch_strs, batch_tokens = batch_converter([("", seq)])
        batch_tokens = batch_tokens.to('cuda' if torch.cuda.is_available() else 'cpu')
        with torch.no_grad():
            results = esm_model(batch_tokens, repr_layers=[33])
        token_representations = results["representations"][33]
        # Single unpadded sequence: drop the start token (index 0) and the end token
        # (index seq_len - 1) added by the batch converter.
        seq_len = (batch_tokens != alphabet.padding_idx).sum(1).item()
        return token_representations[0, 1:seq_len - 1].cpu()


In [None]:

# Example: build the dataset from the comma-joined pair FASTA and embed the first pair.
fasta_file = '/content/sequences_esm-26.fasta'
dataset = ProteinPairsDataset(fasta_file)

first_item = dataset[0]  # {header: concatenated embedding}
print(first_item)


{'6kpa__F1_Q97YJ9--6kpa__G1_Q97YJ9': tensor([[ 0.3139, -0.0509, -0.1930,  ...,  0.0281, -0.1715, -0.0755],
        [ 0.1223, -0.0933, -0.3406,  ..., -0.0361,  0.0551, -0.1054],
        [-0.1298,  0.0771, -0.1424,  ...,  0.0712,  0.0447, -0.1218],
        ...,
        [ 0.1815, -0.1354, -0.0028,  ..., -0.1126,  0.0451, -0.1150],
        [ 0.1240, -0.1464, -0.0027,  ..., -0.0037,  0.0707, -0.0379],
        [ 0.0795, -0.0415, -0.2137,  ..., -0.1478, -0.1127,  0.0556]])}


In [None]:
print(len(dataset))
# Each item is a one-entry {header: embedding} dict; unpack its value instead of
# hard-coding the header, which breaks as soon as the FASTA order changes.
(first_embedding,) = dataset[0].values()
print(first_embedding.shape)

4697
torch.Size([656, 1280])


In [None]:
# Embed every pair; processed_data maps header -> concatenated embedding tensor.
# NOTE(review): all tensors are held in RAM before pickling; watch memory on large sets.
processed_data = {}
for item in (dataset[i] for i in tqdm(range(len(dataset)))):
    processed_data.update(item)

100%|██████████| 4697/4697 [21:48<00:00,  3.59it/s]


In [None]:
import pickle

# Persist the header -> embedding mapping to the working directory.
with open('protein_pairs_dataset.pkl', mode='wb') as out_file:
    pickle.dump(processed_data, out_file)

In [None]:
!pip install boto3
!pip install awscli

Collecting boto3
  Downloading boto3-1.34.8-py3-none-any.whl (139 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.3/139.3 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting botocore<1.35.0,>=1.34.8 (from boto3)
  Downloading botocore-1.34.8-py3-none-any.whl (11.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.9/11.9 MB[0m [31m82.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jmespath<2.0.0,>=0.7.1 (from boto3)
  Downloading jmespath-1.0.1-py3-none-any.whl (20 kB)
Collecting s3transfer<0.11.0,>=0.10.0 (from boto3)
  Downloading s3transfer-0.10.0-py3-none-any.whl (82 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m82.1/82.1 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jmespath, botocore, s3transfer, boto3
Successfully installed boto3-1.34.8 botocore-1.34.8 jmespath-1.0.1 s3transfer-0.10.0
Collecting awscli
  Downloading awscli-1.32.8-py3-none-any.whl (4.3 MB)
[2K     [90

In [None]:
import boto3
from botocore.handlers import disable_signing
from botocore.exceptions import ClientError

# Anonymous S3 access: unsigned requests, so no AWS credentials are required
# (the same bucket is read with --no-sign-request below).
s3 = boto3.resource('s3')
s3.meta.client.meta.events.register('choose-signer.s3.*', disable_signing)

bucket_name = 'openfold'
example_id = '6kpa'
example_chain = 'F'  # Replace with the correct chain if different

# Object keys are relative to the bucket: s3://openfold/pdb/<id>_<chain>/a3m/... is
# bucket 'openfold' + key 'pdb/<id>_<chain>/a3m/...'. The key must not repeat the bucket name.
file_path = f"pdb/{example_id}_{example_chain}/a3m/bfd_uniclust_hits.a3m"
local_path = f"{example_id}_{example_chain}.bfd_uniclust_hits.a3m"

# Attempt to download the file; report (rather than raise) failures in this exploratory cell.
try:
    s3.Bucket(bucket_name).download_file(file_path, local_path)
    print(f"Downloaded {file_path} to {local_path}")
except ClientError as e:
    print(f"Error downloading {file_path}: {e}")
except Exception as e:
    print(f"An error occurred: {e}")



Error downloading openfold/pdb/6kpa_F/a3m/bfd_uniclust_hits.a3m: An error occurred (404) when calling the HeadObject operation: Not Found


In [None]:
!aws s3 cp --no-sign-request s3://openfold/pdb/7rxc_N/a3m/bfd_uniclust_hits.a3m ./7xrc-openfold

Completed 256.0 KiB/422.3 KiB (973.0 KiB/s) with 1 file(s) remainingCompleted 422.3 KiB/422.3 KiB (1.6 MiB/s) with 1 file(s) remaining  download: s3://openfold/pdb/7rxc_N/a3m/bfd_uniclust_hits.a3m to ./7xrc-openfold


In [None]:
# Download one example OpenFold MSA (101m, chain A) into ./openfold_101m_A/ with unsigned requests.
!aws s3 cp --no-sign-request s3://openfold/pdb/101m_A/a3m/bfd_uniclust_hits.a3m ./openfold_101m_A/

Completed 256.0 KiB/269.6 KiB (674.2 KiB/s) with 1 file(s) remainingCompleted 269.6 KiB/269.6 KiB (708.7 KiB/s) with 1 file(s) remainingdownload: s3://openfold/pdb/101m_A/a3m/bfd_uniclust_hits.a3m to openfold_101m_A/bfd_uniclust_hits.a3m


## MSA-Header-Sequence Mapping

In [None]:
!pip install biopython

Collecting biopython
  Downloading biopython-1.82-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m27.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: biopython
Successfully installed biopython-1.82


In [None]:
import pandas as pd
import pickle
from Bio import SeqIO

# Mapping from amino-acid sequence to its MSA id(s), loaded from disk.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('seq_to_msa_id_dict.pkl', 'rb') as file:
    seq_to_msa_id_dict = pickle.load(file)

fasta_file = 'sequences_esm-26.fasta'

# Each record's sequence is "SEQ1,SEQ2"; keep (header, seq1, seq2) only for well-formed pairs.
with open(fasta_file, 'r') as fasta:
    split_records = (
        (record.description, str(record.seq).split(','))
        for record in SeqIO.parse(fasta, 'fasta')
    )
    sequence_pairs = [
        (header, parts[0], parts[1])
        for header, parts in split_records
        if len(parts) == 2
    ]



In [None]:
# One row per sequence pair; each sequence is mapped to its MSA id(s), or 'Unknown' if unmapped.
data = [
    [header, seq1, seq2,
     seq_to_msa_id_dict.get(seq1, 'Unknown'),
     seq_to_msa_id_dict.get(seq2, 'Unknown')]
    for header, seq1, seq2 in sequence_pairs
]

df = pd.DataFrame(data, columns=['Header', 'Sequence1', 'Sequence2', 'PDB1', 'PDB2'])
print(df.head())

                             Header  \
0  6kpa__F1_Q97YJ9--6kpa__G1_Q97YJ9   
1  4fmg__C1_C0JPK1--4fmg__D1_C0JPK1   
2  7yy4__B2_B1MDL6--7yy4__C2_B1MDL6   
3  4lg2__A1_Q8JPY0--4lg2__B1_Q8JPY0   
4  4jk1__C1_P0A8V2--4jk1__F1_P00579   

                                           Sequence1  \
0  DKTVLDANLDPLKGKTIGVIGYGNQGRVQATIMRENGLNVIVGNVK...   
1  VEVLSVVTGEDSITQIELYLNPRMGVNSPDLPTTSNWYTYTYDLQP...   
2  MTGAVCPGSFDPVTLGHLDVFERAAAQFDEVIVAVLINPNKAGMFT...   
3  ISAKDLKEIMYDHLPGFGTAFHQLVQVICKIGKDNNLLDTIHAEFQ...   
4  KKRIRKDFGKRPQVLDVPYLLSIQLDSFQKFIEQDPEGQYGLEAAF...   

                                           Sequence2     PDB1     PDB2  
0  DKTVLDANLDPLKGKTIGVIGYGNQGRVQATIMRENGLNVIVGNVK...  Unknown  Unknown  
1  VEVLSVVTGEDSITQIELYLNPRMGVNSPDLTSNWYTYTYDLQPKG...  Unknown  Unknown  
2  MTGAVCPGSFDPVTLGHLDVFERAAAQFDEVIVAVLINPAGMFTVD...  Unknown  Unknown  
3  ISAKDLKEIMYDHLPGFGTAFHQLVQVICKIGKDNNLLDTIHAEFQ...  Unknown  Unknown  
4  QSQLKLLVTRGKEQGYLTYAEVNDHLPEDIVDSDQIEDIIQMINDM...  Unknown  

In [None]:
# Number of pairs whose two sequences were both mapped to an MSA id.
both_mapped = (df['PDB1'] != 'Unknown') & (df['PDB2'] != 'Unknown')
count_non_empty_pdb = int(both_mapped.sum())
count_non_empty_pdb


902

In [None]:
# Keep pairs with both MSA ids known and two different sequences (identical-sequence pairs are dropped).
keep = (df['PDB1'] != 'Unknown') & (df['PDB2'] != 'Unknown') & (df['Sequence1'] != df['Sequence2'])
filtered_df = df.loc[keep]
filtered_df

Unnamed: 0,Header,Sequence1,Sequence2,PDB1,PDB2
7,6j5a__B1_A0A287B4I0--6j5a__D1_Q95339,PFDQMTIEDLNEVFPETKLDKKKY,ASVVPLKDRRLLEVKLGELPSWILMRDFTPSGIAGAFQRGYYRYYN...,[6j54_d],[6j54_f]
10,3u8q__A1_P24627--3u8q__B1_P24627,YTRVVWCAVGPEEQKKCQQWSQQSGQNVTCATASTTDDCIVLVLKG...,LEACAF,[1sdx_A],[3tod_B]
18,1ezv__D2_P07143--1ezv__I2_P22289,MTAAEHGLHAPAYAWSHNGPFETFDHASIRRGYQVYREVCAACHSL...,SSLYKTFFKRNAVFVGTIFAGAFVFQTVFDTAITSWYENHNKGKLW...,[1ezv_D],[1ezv_I]
47,5yb2__B1_P04578--5yb2__E1_UNDEFINED,SGIVQQQNNLLRAIEAQQHLLQLTVWGIKQLQARIL,ELTWEEWEKKIEEYTKKIEEILK,[3g7a_A],[5yb2_G]
68,7azf__A1_P0A988--7azf__E1_UNDEFINED,GSHMKFTVEREHLLKPLQQVSGPLGGRPTLPILGNLLLQVADGTLS...,XQADLF,[7azf_A],[3q4k_C]
...,...,...,...,...,...
4628,7fiz__A1_A0A059VAZ3--7fiz__B1_A0A059VAZ3,KEFEVLSFEIDEQALAFDVDNIEMVIEKSDITPVPKSRHFVEGVIN...,GSHMKDVQTETFSVAESIEEISKANEEITNQLLGISKEMDNISTRI...,[3ja6_A],[2ch7_A]
4639,4ysn__A1_M1GRN3--4ysn__D1_M1GRN3,MQLNSTEISELIKQRIAQFNVVSEAHNEGTIVSVSDGVIRIHGLAD...,MNLNATILGQAIAFVLFVLFAMKYVWPPLMAAIEKRQKEIADGLAS...,[3oaa_A],[6oqt_X]
4658,6cuf__C1_UNDEFINED--6cuf__U1_Q2N0S6,TENFNMWKNDMVEQMHEDIISLWDQSLKPCVKLTP,GSDTITLPCRIKQIINMWQKVGKAMYAPPISGQIRCSSNITGLLLT...,[3dnl_A],[3dnl_C]
4688,7x1u__C1_UNDEFINED--7x1u__D1_P62871,XSLSDKDKAAVRALWSKIGKSSDAIGNDALSRMIVVYPQTKIYFSH...,VEWTDKERSIISDIFSHMDYDDIGPKALSRCLVVYPWTQRYFSGFG...,[1la6_A],[1la6_B]


In [None]:
# have distograms (gdrive) + corresponding pdbs (gdrive/dcc) + msas (gdrive) + sequences (gdrive)

In [None]:
# get msas downloaded -- primer: !aws s3 cp --no-sign-request s3://openfold/pdb/101m_A/a3m/bfd_uniclust_hits.a3m ./openfold_101m_A/

In [None]:
# PDB1/PDB2 hold lists of MSA ids (e.g. ['6j54_d']); keep the first id of each list.
# `.assign` returns a new frame, so this no longer writes into the `filtered_df` slice
# (the source of the SettingWithCopyWarning) and is safe to re-run.
def _first_if_list(value):
    """Return value[0] for a list, otherwise the value unchanged."""
    return value[0] if isinstance(value, list) else value

filtered_df = filtered_df.assign(
    PDB1=filtered_df['PDB1'].map(_first_if_list),
    PDB2=filtered_df['PDB2'].map(_first_if_list),
)
unique_pdbs = set(filtered_df['PDB1']).union(filtered_df['PDB2'])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_df['PDB1'] = filtered_df['PDB1'].apply(lambda x: x[0] if isinstance(x, list) else x)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_df['PDB2'] = filtered_df['PDB2'].apply(lambda x: x[0] if isinstance(x, list) else x)


In [None]:
!pip install awscli

Collecting awscli
  Downloading awscli-1.32.9-py3-none-any.whl (4.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.3/4.3 MB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting botocore==1.34.9 (from awscli)
  Downloading botocore-1.34.9-py3-none-any.whl (11.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.9/11.9 MB[0m [31m38.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docutils<0.17,>=0.10 (from awscli)
  Downloading docutils-0.16-py2.py3-none-any.whl (548 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m548.2/548.2 kB[0m [31m45.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting s3transfer<0.11.0,>=0.10.0 (from awscli)
  Downloading s3transfer-0.10.0-py3-none-any.whl (82 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m82.1/82.1 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
Collecting colorama<0.4.5,>=0.2.5 (from awscli)
  Downloading colorama-0.4.4-py2.py3-none-any.whl (16 kB)
Collecti

In [None]:
import os
import subprocess

output_dir = "/content/drive/MyDrive/flamingo-pep-gen/data_dump/openfold_esm-26/"
os.makedirs(output_dir, exist_ok=True)

failed_downloads = []
# Ids look like '1abc_A' (PDB id + chain), which is exactly the directory name used
# under s3://openfold/pdb/. Sorted for a deterministic download order.
for pdb_id in sorted(unique_pdbs):
    local_path = os.path.join(output_dir, f"{pdb_id}.a3m")
    if os.path.exists(local_path):
        continue  # already downloaded on a previous run
    s3_path = f"s3://openfold/pdb/{pdb_id}/a3m/bfd_uniclust_hits.a3m"

    # Download the MSA; record failures (e.g. ids absent from the bucket) instead of ignoring them.
    result = subprocess.run(["aws", "s3", "cp", "--no-sign-request", s3_path, local_path])
    if result.returncode != 0:
        failed_downloads.append(pdb_id)

print(f"{len(failed_downloads)} of {len(unique_pdbs)} MSA downloads failed")


In [None]:
## ESM-2 embed both pdb1/2 seqs and concatenate embeddings

In [None]:
!pip install fair-esm

import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from transformers import T5ForConditionalGeneration, T5Tokenizer
import esm
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

# Reload ESM-2 650M and its alphabet/tokenizer for this section (eval mode, GPU when available).
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_model.eval()
if torch.cuda.is_available():
    esm_model.cuda()  # moves the parameters in place
batch_converter = alphabet.get_batch_converter()

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0


Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt


In [None]:
import torch
from torch.utils.data import Dataset
import esm

class ProteinPairsDataset(Dataset):
    """ESM-2 embeddings for the sequence pairs in a DataFrame (one item per row).

    Note: this re-defines, and therefore shadows, the FASTA-based
    ``ProteinPairsDataset`` declared earlier in the notebook.

    The DataFrame must have columns ``PDB1``, ``PDB2``, ``Sequence1`` and ``Sequence2``.
    Each item is ``{f"{PDB1}--{PDB2}": embedding}``, where ``embedding`` concatenates
    the per-token representations of Sequence1 and Sequence2 along dim 0.

    NOTE(review): the key is built from the MSA ids, not the pair ``Header``. Two rows
    mapped to the same MSA ids therefore produce the same key and overwrite each other
    when items are merged into one dict.

    Args:
        dataframe: table of sequence pairs.
        esm_model, alphabet: an already-loaded ESM model and its alphabet, to reuse
            a model that is already in memory. Pass both or neither. When omitted
            (default), ``esm2_t33_650M_UR50D`` is loaded.
        repr_layer: index of the layer whose representations are returned (default 33).
    """

    def __init__(self, dataframe, esm_model=None, alphabet=None, repr_layer=33):
        self.sequence_data = dataframe
        if (esm_model is None) != (alphabet is None):
            raise ValueError("Pass both esm_model and alphabet, or neither")
        if esm_model is None:
            esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm_model = esm_model.eval()
        if torch.cuda.is_available():
            self.esm_model = self.esm_model.cuda()
        self.alphabet = alphabet
        self.batch_converter = self.alphabet.get_batch_converter()
        self.repr_layer = repr_layer

    def __len__(self):
        return len(self.sequence_data)

    def __getitem__(self, idx):
        """Return ``{"<PDB1>--<PDB2>": embedding(Sequence1) ++ embedding(Sequence2)}`` for row ``idx``."""
        row = self.sequence_data.iloc[idx]
        header = f"{row['PDB1']}--{row['PDB2']}"
        emb1 = self.generate_esm_embedding(row['Sequence1'])
        emb2 = self.generate_esm_embedding(row['Sequence2'])
        return {header: torch.cat((emb1, emb2), dim=0)}

    def generate_esm_embedding(self, seq):
        """Representation of ``seq`` at ``self.repr_layer`` on CPU, without the first/last special tokens."""
        _, _, batch_tokens = self.batch_converter([("", seq)])
        batch_tokens = batch_tokens.to('cuda' if torch.cuda.is_available() else 'cpu')
        with torch.no_grad():
            results = self.esm_model(batch_tokens, repr_layers=[self.repr_layer])
        token_representations = results["representations"][self.repr_layer]
        # Single unpadded sequence: drop the start token (index 0) and the end token
        # (index seq_len - 1) added by the batch converter.
        seq_len = (batch_tokens != self.alphabet.padding_idx).sum(1).item()
        return token_representations[0, 1:seq_len - 1].cpu()


In [None]:
# Pair dataset over the filtered rows (keys are "<PDB1>--<PDB2>" MSA ids).
protein_dataset = ProteinPairsDataset(dataframe=filtered_df)

In [None]:
# Embed every row. Keys are "<PDB1>--<PDB2>" MSA-id pairs, so different rows can map to
# the same key; the later row wins (as dict.update does), and collisions are reported.
concatenated_embeddings = {}
overwritten_keys = []
for i in tqdm(range(len(protein_dataset)), desc="embedding pairs"):
    data = protein_dataset[i]
    overwritten_keys.extend(key for key in data if key in concatenated_embeddings)
    concatenated_embeddings.update(data)

if overwritten_keys:
    print(f"Warning: {len(overwritten_keys)} pair keys came from more than one row and were "
          f"overwritten, e.g. {overwritten_keys[:5]}")

In [None]:
# Destination of the MSA-id-keyed pair embeddings.
# NOTE: this rebinds `save_path`, which the distogram cells near the top also use.
save_path = ("/content/drive/MyDrive/flamingo-pep-gen/data_dump/"
             "distogram_pred/pred-test-pairs-26.pkl")

In [None]:
import pickle

# Write the {"<PDB1>--<PDB2>": embedding} dict to save_path.
with open(save_path, mode='wb') as file:
    pickle.dump(concatenated_embeddings, file=file)

In [None]:
# List the distogram_pred folder to confirm the embeddings pickle was written.
!ls /content/drive/MyDrive/flamingo-pep-gen/data_dump/distogram_pred/

contact_prediction.ipynb  pred-test-26.pkl	  protein_pairs_dataset.pkl
metadata.csv		  pred-test-pairs-26.pkl


## Downloads

In [None]:
import esm

In [None]:
!pip install biopython biotite
!pip install git+https://github.com/facebookresearch/esm.git
!apt-get install aria2

Collecting biopython
  Downloading biopython-1.82-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting biotite
  Downloading biotite-0.38.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (50.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.5/50.5 MB[0m [31m30.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: biopython, biotite
Successfully installed biopython-1.82 biotite-0.38.0
Collecting git+https://github.com/facebookresearch/esm.git
  Cloning https://github.com/facebookresearch/esm.git to /tmp/pip-req-build-ksneknj9
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/esm.git /tmp/pip-req-build-ksneknj9
  Resolved https://github.com/facebookresearch/esm.git to commit 2b369911bb5b4b0dda914521b9475cad1656b2ac
  Installing build dependencies ... [?

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  libaria2-0 libc-ares2
The following NEW packages will be installed:
  aria2 libaria2-0 libc-ares2
0 upgraded, 3 newly installed, 0 to remove and 24 not upgraded.
Need to get 1,513 kB of archives.
After this operation, 5,441 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy-updates/main amd64 libc-ares2 amd64 1.18.1-1ubuntu0.22.04.2 [45.0 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 libaria2-0 amd64 1.36.0-1 [1,086 kB]
Get:3 http://archive.ubuntu.com/ubuntu jammy/universe amd64 aria2 amd64 1.36.0-1 [381 kB]
Fetched 1,513 kB in 1s (1,583 kB/s)
Selecting previously unselected package libc-ares2:amd64.
(Reading database ... 121658 files and directories currently installed.)
Preparing to unpack .../libc-ares2_1.18.1-1ubuntu0.22.04.2_amd64.deb ...
Unpacking libc-ares2:amd64 (1.18.1-1ubun

In [None]:
!mkdir -p /root/.cache/torch/hub/checkpoints
!aria2c --dir=/root/.cache/torch/hub/checkpoints --continue --split 8 --max-connection-per-server 8\
    https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50S.pt
!aria2c --dir=/root/.cache/torch/hub/checkpoints --continue --split 8 --max-connection-per-server 8\
    https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50S-contact-regression.pt
!aria2c --dir=/root/.cache/torch/hub/checkpoints --continue --split 8 --max-connection-per-server 8\
    https://dl.fbaipublicfiles.com/fair-esm/models/esm_msa1b_t12_100M_UR50S.pt
!aria2c --dir=/root/.cache/torch/hub/checkpoints --continue --split 8 --max-connection-per-server 8\
    https://dl.fbaipublicfiles.com/fair-esm/regression/esm_msa1b_t12_100M_UR50S-contact-regression.pt


12/27 16:36:13 [[1;32mNOTICE[0m] Downloading 1 item(s)

12/27 16:36:13 [[1;31mERROR[0m] CUID#7 - Download aborted. URI=https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50S.pt
Exception: [AbstractCommand.cc:351] errorCode=22 URI=https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50S.pt
  -> [HttpSkipResponseCommand.cc:239] errorCode=22 The response status is not successful. status=403

12/27 16:36:13 [[1;32mNOTICE[0m] Download GID#fda8b689c2726f72 not complete: 

Download Results:
gid   |stat|avg speed  |path/URI
fda8b6|[1;31mERR[0m |       0B/s|https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50S.pt

Status Legend:
(ERR):error occurred.

aria2 will resume download if the transfer is restarted.
If there are any errors, then see the log file. See '-l' option in help/man page for details.

12/27 16:36:13 [[1;32mNOTICE[0m] Downloading 1 item(s)

12/27 16:36:14 [[1;31mERROR[0m] CUID#7 - Download aborted. URI=https://dl.fbaipublicfiles.com

## Define Functions

In [None]:
!pip install biotite

Collecting biotite
  Downloading biotite-0.38.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (50.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.5/50.5 MB[0m [31m20.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: biotite
Successfully installed biotite-0.38.0


In [None]:
from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable
import itertools
import os
import string
from pathlib import Path

import numpy as np
import torch
from scipy.spatial.distance import squareform, pdist, cdist
import matplotlib.pyplot as plt
import matplotlib as mpl
from Bio import SeqIO
import biotite.structure as bs
from biotite.structure.io.pdbx import PDBxFile, get_structure
from biotite.database import rcsb
from tqdm import tqdm
import pandas as pd

import esm

# Inference-only notebook section: turn autograd off globally for the rest of the kernel.
torch.set_grad_enabled(mode=False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f51902eca30>

### Parsing alignments

In [None]:
# str.translate table that deletes lowercase letters (a3m insertions) and the '.' / '*'
# characters in one pass; keys map to None, which means "delete".
deletekeys = {ch: None for ch in string.ascii_lowercase + ".*"}
translation = str.maketrans(deletekeys)

def read_sequence(filename: str) -> Tuple[str, str]:
    """Return ``(description, sequence)`` of the first record of a FASTA/a3m file.

    Raises StopIteration if the file contains no records.
    """
    first_record = next(iter(SeqIO.parse(filename, "fasta")))
    return first_record.description, str(first_record.seq)

def remove_insertions(sequence: str) -> str:
    """Strip insertion characters (lowercase letters, '.', '*') from an aligned MSA row."""
    return str.translate(sequence, translation)

def read_msa(filename: str, max_seqs: Optional[int] = None) -> List[Tuple[str, str]]:
    """Read an MSA file as ``[(description, sequence), ...]`` with insertions removed.

    Args:
        filename: FASTA/a3m file; the first record is presumably the query.
        max_seqs: if given, stop after the first ``max_seqs`` records. This avoids
            parsing very deep MSAs when only a subsample is needed (e.g. before
            ``greedy_select``). ``None`` (default) reads every record.
    """
    records = SeqIO.parse(filename, "fasta")
    if max_seqs is not None:
        records = itertools.islice(records, max_seqs)
    return [(record.description, remove_insertions(str(record.seq))) for record in records]

### Converting structures to contacts

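There are many ways to define a protein contact. Here two residues are in contact when their carbon beta atoms are closer than a distance threshold; `contacts_from_pdb` below defaults to 10 angstroms, matching the 10 Å distograms computed above. Note that the position of the carbon beta is imputed from the positions of the N, CA, and C atoms of each residue, so every residue (including glycine) gets one.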

In [None]:
def extend(a, b, c, L, A, D):
    """
    Place a fourth point from three reference points using internal coordinates.

    input:  3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral; angles in radians.
            Coordinates may carry leading batch dimensions (last axis = xyz).
    output: 4th coord = c + L*cos(A)*u + L*sin(A)*cos(D)*v - L*sin(A)*sin(D)*n,
            where u is the unit vector from c towards b, n is the unit normal of
            the plane spanned by (b - a) and u, and v = n x u. The result lies at
            distance L from c and makes angle A with the c -> b direction.
    """
    def _unit(vec):
        return vec / np.linalg.norm(vec, ord=2, axis=-1, keepdims=True)

    u = _unit(b - c)
    n = _unit(np.cross(b - a, u))
    v = np.cross(n, u)

    along_u = L * np.cos(A)
    along_v = L * np.sin(A) * np.cos(D)
    along_n = -L * np.sin(A) * np.sin(D)
    return c + (u * along_u + v * along_v + n * along_n)


def contacts_from_pdb(
    structure: bs.AtomArray,
    distance_threshold: float = 10,
    chain: Optional[str] = None,
) -> np.ndarray:
    """Binary C-beta contact map for the non-hetero atoms of ``structure``.

    C-beta positions are imputed from each residue's N, CA and C atoms via ``extend``.
    Two residues are in contact when their C-beta atoms are closer than
    ``distance_threshold`` (default 10, matching the 10 Å distograms above).

    Args:
        structure: biotite AtomArray; hetero atoms are ignored.
        distance_threshold: contact cut-off, in the units of ``structure.coord``.
        chain: optional chain id to restrict the map to.

    Returns:
        (L, L) int64 array: 1 = contact, 0 = no contact, -1 where the distance is NaN.

    Raises:
        ValueError: if the numbers of N, CA and C atoms differ (some residue lacks a
            backbone atom), which would otherwise misalign or broadcast residues.
    """
    mask = ~structure.hetero
    if chain is not None:
        mask &= structure.chain_id == chain

    N = structure.coord[mask & (structure.atom_name == "N")]
    CA = structure.coord[mask & (structure.atom_name == "CA")]
    C = structure.coord[mask & (structure.atom_name == "C")]

    if not (len(N) == len(CA) == len(C)):
        raise ValueError(
            f"Backbone atom counts differ (N={len(N)}, CA={len(CA)}, C={len(C)}); "
            "cannot pair atoms into residues"
        )
    if len(CA) == 0:
        # squareform() of an empty condensed matrix returns 1x1, so return the empty map explicitly.
        return np.zeros((0, 0), dtype=np.int64)

    Cbeta = extend(C, N, CA, 1.522, 1.927, -2.143)
    dist = squareform(pdist(Cbeta))

    contacts = (dist < distance_threshold).astype(np.int64)
    contacts[np.isnan(dist)] = -1
    return contacts

### Subsampling MSA

In [None]:
# Select sequences from the MSA to maximize the hamming distance
# Alternatively, can use hhfilter
def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]:
    """Greedily pick ``num_seqs`` rows of ``msa`` by mean Hamming distance.

    Selection always starts from the first row. It then repeatedly adds the
    not-yet-selected row whose mean Hamming distance to the selected rows is largest
    (``mode="max"``, most diverse subset) or smallest (``mode="min"``).

    The chosen rows are returned in their original MSA order. If ``msa`` has at most
    ``num_seqs`` rows it is returned unchanged, and ``num_seqs <= 0`` returns ``[]``.
    All sequences must have the same length (e.g. after ``remove_insertions``).
    """
    assert mode in ("max", "min")
    if num_seqs <= 0:
        return []
    if len(msa) <= num_seqs:
        return msa

    # One uint8 code per aligned column so cdist can compute Hamming distances.
    array = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8)

    optfunc = np.argmax if mode == "max" else np.argmin
    all_indices = np.arange(len(msa))
    indices = [0]
    # Row k holds distances from the k-th selected sequence to every sequence.
    pairwise_distances = np.zeros((0, len(msa)))
    for _ in range(num_seqs - 1):
        dist = cdist(array[indices[-1:]], array, "hamming")
        pairwise_distances = np.concatenate([pairwise_distances, dist])
        # Mean distance to the selected set, restricted to the still-unselected columns.
        shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0)
        shifted_index = optfunc(shifted_distance)
        index = np.delete(all_indices, indices)[shifted_index]
        indices.append(index)
    indices = sorted(indices)
    return [msa[idx] for idx in indices]

### Compute contact precisions

In [None]:
def compute_precisions(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    src_lengths: Optional[torch.Tensor] = None,
    minsep: int = 6,
    maxsep: Optional[int] = None,
    override_length: Optional[int] = None,  # for casp
):
    """Precision of the highest-scoring predicted contacts, measured at fractions of L.

    Pairs (i, j) with minsep <= j - i (and j - i < maxsep, if given) and a non-negative
    target are ranked by prediction score. Precision is computed over the top
    0.1L, 0.2L, ..., L of that ranking.

    NOTE: this notebook defines this function twice; the cell that ran last wins.

    Args:
        predictions: (L, L) or (B, L, L) scores; numpy arrays are converted to tensors.
        targets: same shape as `predictions`; negative entries mark invalid pairs.
        src_lengths: optional (B,) lengths; positions at or past a length are masked out.
        minsep: minimum sequence separation j - i.
        maxsep: optional exclusive maximum sequence separation.
        override_length: the L used for the cut-offs. When None, it is inferred as the
            number of non-negative targets in the first row of the first batch item.

    Returns:
        dict of (B,) tensors: "AUC" (mean over the ten cut-offs), "P@L",
        "P@L2" (top L/2) and "P@L5" (top L/5).

    Raises:
        ValueError: if `predictions` and `targets` have different sizes.
    """
    if isinstance(predictions, np.ndarray):
        predictions = torch.from_numpy(predictions)
    if isinstance(targets, np.ndarray):
        targets = torch.from_numpy(targets)
    if predictions.dim() == 2:
        predictions = predictions.unsqueeze(0)
    if targets.dim() == 2:
        targets = targets.unsqueeze(0)
    if override_length is None:
        # Infer L only when the caller did not supply one, so an explicit value is honoured.
        override_length = (targets[0, 0] >= 0).sum()

    # Check sizes
    if predictions.size() != targets.size():
        raise ValueError(
            f"Size mismatch. Received predictions of size {predictions.size()}, "
            f"targets of size {targets.size()}"
        )
    device = predictions.device

    batch_size, seqlen, _ = predictions.size()
    seqlen_range = torch.arange(seqlen, device=device)

    # sep[0, i, j] = j - i
    sep = seqlen_range.unsqueeze(0) - seqlen_range.unsqueeze(1)
    sep = sep.unsqueeze(0)
    valid_mask = sep >= minsep
    valid_mask = valid_mask & (targets >= 0)  # negative targets are invalid

    if maxsep is not None:
        valid_mask &= sep < maxsep

    if src_lengths is not None:
        valid = seqlen_range.unsqueeze(0) < src_lengths.unsqueeze(1)
        valid_mask &= valid.unsqueeze(1) & valid.unsqueeze(2)
    else:
        src_lengths = torch.full([batch_size], seqlen, device=device, dtype=torch.long)

    # Invalid pairs sink to the bottom of the ranking.
    predictions = predictions.masked_fill(~valid_mask, float("-inf"))

    x_ind, y_ind = np.triu_indices(seqlen, minsep)
    predictions_upper = predictions[:, x_ind, y_ind]
    targets_upper = targets[:, x_ind, y_ind]

    topk = seqlen if override_length is None else max(seqlen, override_length)
    indices = predictions_upper.argsort(dim=-1, descending=True)[:, :topk]
    topk_targets = targets_upper[torch.arange(batch_size).unsqueeze(1), indices]
    if topk_targets.size(1) < topk:
        # Fewer candidate pairs than cut-off positions: missing slots count as misses.
        topk_targets = F.pad(topk_targets, [0, topk - topk_targets.size(1)])

    cumulative_dist = topk_targets.type_as(predictions).cumsum(-1)

    gather_lengths = src_lengths.unsqueeze(1)
    if override_length is not None:
        gather_lengths = override_length * torch.ones_like(
            gather_lengths, device=device
        )

    # Zero-based positions of the 0.1L, 0.2L, ..., L cut-offs.
    gather_indices = (
        torch.arange(0.1, 1.1, 0.1, device=device).unsqueeze(0) * gather_lengths
    ).type(torch.long) - 1

    binned_cumulative_dist = cumulative_dist.gather(1, gather_indices)
    binned_precisions = binned_cumulative_dist / (gather_indices + 1).type_as(
        binned_cumulative_dist
    )

    pl5 = binned_precisions[:, 1]
    pl2 = binned_precisions[:, 4]
    pl = binned_precisions[:, 9]
    auc = binned_precisions.mean(-1)

    return {"AUC": auc, "P@L": pl, "P@L2": pl2, "P@L5": pl5}


def evaluate_prediction(
    predictions: torch.Tensor,
    targets: torch.Tensor,
) -> Dict[str, float]:
    """Contact-precision metrics for one prediction, bucketed by sequence separation.

    Buckets are local [3, 6), short [6, 12), medium [12, 24) and long [24, inf). Each
    key is "<bucket>_<metric>" for every metric returned by `compute_precisions`.

    NOTE: this notebook defines this function twice; the cell that ran last wins.

    Args:
        predictions: (L, L) contact scores, tensor or numpy array.
        targets: (L, L) ground-truth contacts, tensor or numpy array; negative values
            mark invalid pairs.

    Raises:
        ValueError: propagated from `compute_precisions` when the shapes differ.
    """
    # Both inputs may be numpy; `targets.to(predictions.device)` needs a tensor predictions.
    if isinstance(predictions, np.ndarray):
        predictions = torch.from_numpy(predictions)
    if isinstance(targets, np.ndarray):
        targets = torch.from_numpy(targets)
    contact_ranges = [
        ("local", 3, 6),
        ("short", 6, 12),
        ("medium", 12, 24),
        ("long", 24, None),
    ]
    metrics = {}
    targets = targets.to(predictions.device)
    for name, minsep, maxsep in contact_ranges:
        rangemetrics = compute_precisions(
            predictions,
            targets,
            minsep=minsep,
            maxsep=maxsep,
        )
        for key, val in rangemetrics.items():
            metrics[f"{name}_{key}"] = val.item()
    return metrics

### Plotting Results

In [None]:
"""Adapted from: https://github.com/rmrao/evo/blob/main/evo/visualize.py"""
def plot_contacts_and_predictions(
    predictions: Union[torch.Tensor, np.ndarray],
    contacts: Union[torch.Tensor, np.ndarray],
    ax: Optional[mpl.axes.Axes] = None,
    # artists: Optional[ContactAndPredictionArtists] = None,
    cmap: str = "Blues",
    ms: float = 1,
    title: Union[bool, str, Callable[[float], str]] = True,
    animated: bool = False,
) -> None:
    """Draw predicted contact scores and the top-L contact calls for one (L, L) map.

    The scores are drawn as an image with the lower triangle masked. The top-L scoring
    pairs with |i - j| >= 6 are overlaid as markers: blue for true positives, red for
    false positives, and grey for true contacts that are not in the top L.

    NOTE: this notebook defines this function twice; the cell that ran last wins.

    Args:
        predictions: (L, L) contact scores.
        contacts: (L, L) ground truth. For the markers, any nonzero entry counts as a
            contact. For the title metric, `contacts` is passed unchanged to
            `compute_precisions`, where negative entries mark invalid pairs.
        ax: axes to draw on; defaults to the current axes.
        cmap, animated: forwarded to `imshow`; ms: marker size for `plot`.
        title: a string is used verbatim; a callable receives the long-range
            (sep >= 24) P@L; True uses a default title; False draws no title.
    """
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.detach().cpu().numpy()
    if isinstance(contacts, torch.Tensor):
        contacts = contacts.detach().cpu().numpy()
    if ax is None:
        ax = plt.gca()

    seqlen = contacts.shape[0]
    # relative_distance[i, j] = j - i; the image hides the lower triangle (j < i).
    relative_distance = np.add.outer(-np.arange(seqlen), np.arange(seqlen))
    bottom_mask = relative_distance < 0
    masked_image = np.ma.masked_where(bottom_mask, predictions)
    # Pairs closer than 6 in sequence are excluded from the top-L selection.
    invalid_mask = np.abs(np.add.outer(np.arange(seqlen), -np.arange(seqlen))) < 6
    predictions = predictions.copy()
    predictions[invalid_mask] = float("-inf")

    topl_val = np.sort(predictions.reshape(-1))[-seqlen]
    pred_contacts = predictions >= topl_val

    # Use an explicit boolean mask: `~` on an integer array is bitwise NOT, and `&` fails
    # on float arrays. For 0/1/-1 integer maps this matches the bitwise logic.
    contact_mask = np.asarray(contacts) != 0
    true_positives = contact_mask & pred_contacts & ~bottom_mask
    false_positives = ~contact_mask & pred_contacts & ~bottom_mask
    other_contacts = contact_mask & ~pred_contacts & ~bottom_mask

    if isinstance(title, str):
        title_text: Optional[str] = title
    elif title:
        long_range_pl = compute_precisions(predictions, contacts, minsep=24)[
            "P@L"
        ].item()
        if callable(title):
            title_text = title(long_range_pl)
        else:
            title_text = f"Long Range P@L: {100 * long_range_pl:0.1f}"
    else:
        title_text = None

    ax.imshow(masked_image, cmap=cmap, animated=animated)
    # Markers are plotted at (x=i, y=j), i.e. transposed relative to the image, so they
    # fall in the triangle the image leaves blank.
    ax.plot(*np.where(other_contacts), "o", c="grey", ms=ms)
    ax.plot(*np.where(false_positives), "o", c="r", ms=ms)
    ax.plot(*np.where(true_positives), "o", c="b", ms=ms)
    if title_text is not None:
        ax.set_title(title_text)

    ax.axis("square")
    ax.set_xlim([0, seqlen])
    ax.set_ylim([0, seqlen])

## Predict and Visualize

### Read Data

In [None]:
# Map each complex header to its (PDB1, PDB2) MSA-id pair.
# Assuming filtered_df has columns ['Header', 'PDB1', 'PDB2']
# NOTE(review): `filtered_df` is not defined in the visible notebook, and this cell
# raised NameError. The DataFrame `df` built further down (columns Header/Sequence1/
# Sequence2/PDB1/PDB2, filtered to known ids) looks like the intended source. Confirm,
# then move this cell after that one.
header_to_pdb_mapping = dict(zip(filtered_df['Header'], zip(filtered_df['PDB1'], filtered_df['PDB2'])))
header_to_pdb_mapping

NameError: ignored

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from Bio.PDB import PDBParser
from Bio import SeqIO

# Directories for your data
# NOTE(review): `distogram_directory` differs from the `save_path` that the distogram
# generation cell writes `*_matrix.npy` files to
# (.../flamingo-pep-gen/pinder-dataset/2023-11/distograms-10A/). Unless the files were
# copied here, `read_contact_map` will return None. Confirm which location is current.
pdb_directory = "/content/drive/MyDrive/flamingo-pep-gen/pinder-dataset/2023-11/pdbs/2023-11/pdbs/"
msa_directory = "/content/drive/MyDrive/flamingo-pep-gen/data_dump/openfold_esm-26"
distogram_directory = "/content/drive/MyDrive/distograms-10A"

# Function to read PDB structure
def read_pdb_structure(pdb_id):
    """Parse `<pdb_directory>/<pdb_id>.pdb` with Biopython; return None if the file is missing."""
    pdb_path = os.path.join(pdb_directory, f"{pdb_id}.pdb")
    if not os.path.exists(pdb_path):
        return None
    parser = PDBParser(QUIET=True)
    return parser.get_structure(pdb_id, pdb_path)

# Function to read MSA
def read_msa(pdb_id):
    """Load the A3M alignment for an MSA id from `msa_directory` as Biopython SeqRecords.

    NOTE: this redefines the earlier `read_msa(filename)` helper with a different
    signature and return type. It returns SeqRecords, and insertions are NOT removed.

    Args:
        pdb_id: an MSA id (file stem without ".a3m"), or a list/tuple of candidate ids.
            The PDB1/PDB2 columns of `df` display as lists (e.g. `[6j54_d]`).
            Interpolating a list into the path would look for "['6j54_d'].a3m" and
            silently find nothing. Candidates are tried in order, and the first existing
            file is read.

    Returns:
        list of SeqRecord, or [] when no matching file exists.
    """
    candidates = list(pdb_id) if isinstance(pdb_id, (list, tuple)) else [pdb_id]
    for candidate in candidates:
        msa_file_path = os.path.join(msa_directory, f"{candidate}.a3m")
        if os.path.exists(msa_file_path):
            return list(SeqIO.parse(msa_file_path, "fasta"))
    return []

# Function to read contact map
def read_contact_map(header):
    """Load `<distogram_directory>/<header>_matrix.npy`, or return None if it does not exist."""
    matrix_path = os.path.join(distogram_directory, f"{header}_matrix.npy")
    if not os.path.exists(matrix_path):
        return None
    return np.load(matrix_path)


In [None]:

# Example usage for a specific header
# The same header string names the complex's PDB file and its saved contact matrix.
header = "6j5a__B1_A0A287B4I0--6j5a__D1_Q95339"
# Falls back to ("Unknown", "Unknown") when the header is not in the mapping.
pdb_id1, pdb_id2 = header_to_pdb_mapping.get(header, ("Unknown", "Unknown"))
complex_structure = read_pdb_structure(header)
msa1 = read_msa(pdb_id1)
msa2 = read_msa(pdb_id2)
contact_map = read_contact_map(header)


In [None]:
# Print the first sequence of msa2 (the second chain's MSA), if it was loaded.
if msa2:
    print(msa2[0].seq)

In [None]:
# Show the MSA ids resolved for the two chains.
pdb_id1,pdb_id2

In [None]:
# Ground-truth contact matrix (None if the .npy file was not found).
contact_map

In [None]:
# Parsed Biopython structure for the complex (None if the PDB file was not found).
complex_structure

## Profile-Profile Alignment of MSAs + Simple Concatenation

### HH-suite & UniRef30 (Uniclust) database download

In [None]:
# Install HH-suite (provides the `hhblits` and `hhalign` binaries used below) from the Ubuntu archive.
!sudo apt-get install hhsuite

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  hhsuite-data
The following NEW packages will be installed:
  hhsuite hhsuite-data
0 upgraded, 2 newly installed, 0 to remove and 24 not upgraded.
Need to get 17.7 MB of archives.
After this operation, 480 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 hhsuite-data all 3.3.0+ds-6 [4,470 kB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd64 hhsuite amd64 3.3.0+ds-6 [13.2 MB]
Fetched 17.7 MB in 3s (5,175 kB/s)
debconf: unable to initialize frontend: Dialog
debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debconf/FrontEnd/Dialog.pm line 78, <> line 2.)
debconf: falling back to frontend: Readline
debconf: unable to initialize frontend: Readline
debconf: (This frontend requires a controlling tty.)
debconf: fal

In [None]:
# Create the directory that will hold the UniRef30 HH-suite database.
!mkdir -p /content/drive/MyDrive/flamingo-pep-gen/data_dump/UniRef30

In [None]:
# NOTE(review): this changes the kernel's working directory for the rest of the session.
# Later cells that open relative paths (e.g. 'seq_to_msa_id_dict.pkl') will resolve them
# here instead of in the project root.
os.chdir('/content/drive/MyDrive/flamingo-pep-gen/data_dump/UniRef30')

In [None]:
# List what has already been downloaded into the current (UniRef30) directory.
!ls

UniRef30_2023_02_hhsuite.tar.gz  UniRef30_latest.tar.gz  UniRef30_latest.tar.gz.1


In [None]:
# Download the UniRef30 2023_02 HH-suite database (~66 GB); `-c` resumes a partial download.
!wget -c -P /content/drive/MyDrive/flamingo-pep-gen/data_dump/UniRef30/ https://gwdu111.gwdg.de/~compbiol/uniclust/2023_02/UniRef30_2023_02_hhsuite.tar.gz

--2023-12-28 01:23:05--  https://gwdu111.gwdg.de/~compbiol/uniclust/2023_02/UniRef30_2023_02_hhsuite.tar.gz
Resolving gwdu111.gwdg.de (gwdu111.gwdg.de)... 134.76.10.111
Connecting to gwdu111.gwdg.de (gwdu111.gwdg.de)|134.76.10.111|:443... connected.
HTTP request sent, awaiting response... 206 Partial Content
Length: 70555922504 (66G), 67862745504 (63G) remaining [application/x-gzip]
Saving to: ‘/content/drive/MyDrive/flamingo-pep-gen/data_dump/UniRef30/UniRef30_2023_02_hhsuite.tar.gz’

   UniRef30_2023_02   6%[>                   ]   4.00G   881KB/s    eta 12h 31m^C


In [None]:
# !tar -xvzf /content/drive/MyDrive/flamingo-pep-gen/data_dump/UniRef30/UniRef30_latest.tar.gz -C /content/drive/MyDrive/flamingo-pep-gen/data_dump/UniRef30/


In [None]:
import subprocess
import os
from Bio import AlignIO

def align_sequences(msa_path, output_hmm, database_path, check=False):
    """
    Run `hhblits` on an input alignment against a database and write the result as A3M.

    Despite the parameter name, the file written to `output_hmm` is an A3M alignment
    (`-oa3m`), not an HHM profile.

    Args:
        msa_path: input sequence/alignment file passed to `hhblits -i`.
        output_hmm: destination path for the A3M output (`-oa3m`).
        database_path: HH-suite database prefix passed to `-d`.
        check: if True, raise `subprocess.CalledProcessError` when hhblits exits with a
            non-zero status. Defaults to False (fire-and-forget).

    Returns:
        subprocess.CompletedProcess, so callers can inspect `returncode`.
    """
    return subprocess.run(
        ['hhblits', '-i', msa_path, '-d', database_path, '-oa3m', output_hmm],
        check=check,
    )

def concatenate_aligned_msas(output_path, output_file_name, hmm1, hmm2, check=False):
    """
    Align two alignments/profiles with `hhalign` and write the result file.

    NOTE(review): despite the name, this does not concatenate MSAs.
    `hhalign -i A -t B -o F` aligns query A against template B (pairwise
    profile-profile) and writes the HHR-format result to F. Building a paired or
    concatenated MSA from that result is not done here.

    Args:
        output_path: directory for the result file.
        output_file_name: result file name (e.g. "aligned_profile.hhr").
        hmm1: query alignment/profile passed to `-i`.
        hmm2: template alignment/profile passed to `-t`.
        check: if True, raise `subprocess.CalledProcessError` on a non-zero exit.

    Returns:
        subprocess.CompletedProcess, so callers can inspect `returncode`.
    """
    aligned_profiles = os.path.join(output_path, output_file_name)
    return subprocess.run(
        ['hhalign', '-i', hmm1, '-t', hmm2, '-o', aligned_profiles],
        check=check,
    )

# Example Usage
# NOTE(review): `pdb_id1` / `pdb_id2` come from `header_to_pdb_mapping`. Its values appear
# to be lists of MSA ids (see the PDB1/PDB2 columns of `df`), and interpolating a list here
# gives a path like ".../['6j54_d'].a3m". Confirm the ids are plain strings.
output_path = '/content/drive/MyDrive/flamingo-pep-gen/data_dump/aligned_profiles/'
msa_path1 = f'/content/drive/MyDrive/flamingo-pep-gen/data_dump/openfold_esm-26/{pdb_id1}.a3m'
msa_path2 = f'/content/drive/MyDrive/flamingo-pep-gen/data_dump/openfold_esm-26/{pdb_id2}.a3m'
output_file_name = 'aligned_profile.hhr'
# NOTE(review): the UniRef30 download above was interrupted at ~6%, so this database may
# not have been extracted yet.
database_path = "/content/drive/MyDrive/flamingo-pep-gen/data_dump/UniRef30/UniRef30_2023_02_hhsuite"  # Update this path as needed

# `align_sequences` writes A3M alignments (hhblits -oa3m) to these paths despite the .hmm suffix.
hmm1 = os.path.join(output_path, 'profile1.hmm')
hmm2 = os.path.join(output_path, 'profile2.hmm')

# The HH-suite tools are not expected to create missing parent folders for their outputs.
os.makedirs(output_path, exist_ok=True)

# Align each set of sequences separately
align_sequences(msa_path1, hmm1, database_path)
align_sequences(msa_path2, hmm2, database_path)

# Pairwise profile-profile alignment of the two results (writes an .hhr file)
concatenate_aligned_msas(output_path, output_file_name, hmm1, hmm2)

## ESM-2 Predictions

In [None]:
# Install Meta's ESM package (`fair-esm`; this runtime installed version 2.0.0).
!pip install fair-esm
import esm

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━[0m [32m61.4/93.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0


In [None]:
# Load the `esm2_t33_650M_UR50D` checkpoint in eval mode on the GPU (requires CUDA).
esm2, esm2_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm2 = esm2.eval().cuda()
esm2_batch_converter = esm2_alphabet.get_batch_converter()

In [None]:
# Load the MSA Transformer checkpoint (`esm_msa1b_t12_100M_UR50S`) in eval mode on the GPU.
# NOTE(review): this rebinds `esm2`, `esm2_alphabet` and `esm2_batch_converter`, replacing
# the ESM-2 model loaded in the previous cell. The prediction loop below uses
# `esm2_model` / `batch_converter` (loaded separately from the ESM-2 checkpoint), so the MSA
# Transformer is not what produces the "msatransform" plots. Confirm intent.
esm2, esm2_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
esm2 = esm2.eval().cuda()
esm2_batch_converter = esm2_alphabet.get_batch_converter()

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm_msa1b_t12_100M_UR50S.pt" to /root/.cache/torch/hub/checkpoints/esm_msa1b_t12_100M_UR50S.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm_msa1b_t12_100M_UR50S-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm_msa1b_t12_100M_UR50S-contact-regression.pt


In [None]:
# Biopython is also installed in the notebook's first cell; this repeat only matters on a
# fresh runtime where that cell was skipped.
!pip install biopython

Collecting biopython
  Downloading biopython-1.82-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.1/3.1 MB[0m [31m2.0 MB/s[0m eta [36m0:00:02[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━[0m [32m2.6/3.1 MB[0m [31m37.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m33.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: biopython
Successfully installed biopython-1.82


In [None]:
import pandas as pd
import pickle
from Bio import SeqIO

import os

# Absolute project root. The relative paths used before resolved against the current
# working directory, which an earlier cell changes to data_dump/UniRef30.
PROJECT_DIR = '/content/drive/MyDrive/flamingo-pep-gen'

# Load the sequence -> MSA id mapping dictionary.
# NOTE: pickle.load can execute arbitrary code; only load pickles produced by this project.
with open(os.path.join(PROJECT_DIR, 'seq_to_msa_id_dict.pkl'), 'rb') as file:
    seq_to_msa_id_dict = pickle.load(file)

# Load the sequences from the fasta file
fasta_file = os.path.join(PROJECT_DIR, 'sequences_esm-26.fasta')
sequence_pairs = []

# Each record holds two chain sequences joined by a comma; keep only records with exactly two.
# NOTE: the loop variable `header` overwrites the `header` set in the example-usage cell.
with open(fasta_file, 'r') as fasta:
    for record in SeqIO.parse(fasta, 'fasta'):
        header = record.description
        sequences = str(record.seq).split(',')
        if len(sequences) == 2:
            sequence_pairs.append((header, sequences[0], sequences[1]))



In [None]:
# Build a DataFrame with columns Header, Sequence1, Sequence2, PDB1, PDB2, where PDB1/PDB2 are
# the MSA ids looked up in `seq_to_msa_id_dict` by exact chain sequence.
data = []

# Unmatched sequences get the placeholder 'Unknown'
for header, seq1, seq2 in sequence_pairs:
    pdb1 = seq_to_msa_id_dict.get(seq1, 'Unknown')
    pdb2 = seq_to_msa_id_dict.get(seq2, 'Unknown')
    data.append([header, seq1, seq2, pdb1, pdb2])

# Create a DataFrame, then keep only complexes where both chains have an MSA id
df = pd.DataFrame(data, columns=['Header', 'Sequence1', 'Sequence2', 'PDB1', 'PDB2'])
df = df[(df['PDB1'] != 'Unknown') & (df['PDB2'] != 'Unknown')]

In [None]:
# Inspect the filtered complexes.
df

Unnamed: 0,Header,Sequence1,Sequence2,PDB1,PDB2
7,6j5a__B1_A0A287B4I0--6j5a__D1_Q95339,PFDQMTIEDLNEVFPETKLDKKKY,ASVVPLKDRRLLEVKLGELPSWILMRDFTPSGIAGAFQRGYYRYYN...,[6j54_d],[6j54_f]
10,3u8q__A1_P24627--3u8q__B1_P24627,YTRVVWCAVGPEEQKKCQQWSQQSGQNVTCATASTTDDCIVLVLKG...,LEACAF,[1sdx_A],[3tod_B]
18,1ezv__D2_P07143--1ezv__I2_P22289,MTAAEHGLHAPAYAWSHNGPFETFDHASIRRGYQVYREVCAACHSL...,SSLYKTFFKRNAVFVGTIFAGAFVFQTVFDTAITSWYENHNKGKLW...,[1ezv_D],[1ezv_I]
19,3hi4__A1_P22862--3hi4__B1_P22862,STFVAKDGTQIYFKDWGSGKPVLFSHGWPLDADMWEYQMEYLSSRG...,STFVAKDGTQIYFKDWGSGKPVLFSHGWPLDADMWEYQMEYLSSRG...,[3hi4_A],[3hi4_A]
20,2ooz__A1_P14174--2ooz__B1_P14174,PMFIVNTNVPRASVPDGFLSELTQQLAQATGKPPQYIAVHVVPDQL...,PMFIVNTNVPRASVPDGFLSELTQQLAQATGKPPQYIAVHVVPDQL...,[1ca7_A],[1ca7_A]
...,...,...,...,...,...
4659,6dw5__A1_Q9Y3Z3--6dw5__D1_Q9Y3Z3,SELDAKLNKLGVDRIAISPYKQWTRGYMEPGNIGNGYVTGLKVDAG...,SELDAKLNKLGVDRIAISPYKQWTRGYMEPGNIGNGYVTGLKVDAG...,[1ibt_A],[1ibt_A]
4662,6lth__A1_P51532--6lth__B1_O14497,XGKLEAIAQKLEAIAKKLEAIAWKLEAIAQGAGX,XGKLEAIAQKLEAIAKKLEAIAWKLEAIAQGAGX,[6q5q_A],[6q5q_A]
4688,7x1u__C1_UNDEFINED--7x1u__D1_P62871,XSLSDKDKAAVRALWSKIGKSSDAIGNDALSRMIVVYPQTKIYFSH...,VEWTDKERSIISDIFSHMDYDDIGPKALSRCLVVYPWTQRYFSGFG...,[1la6_A],[1la6_B]
4692,6nd1__B1_P14906--6nd1__E1_P33754,TVPDRDNDGIPDSLEVEGYTVDVKNKRTFLSPWISNIHEKKGLTKY...,TVPDRDNDGIPDSLEVEGYTVDVKNKRTFLSPWISNIHEKKGLTKY...,[1tzn_A],[1tzn_A]


In [None]:
import torch
import numpy as np
import pandas as pd
import os
from Bio import AlignIO
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load ESM-2 (`esm2_t33_650M_UR50D`) and its batch converter; eval mode, GPU when available.
# NOTE(review): this loads a second copy of the checkpoint that was earlier bound to `esm2`,
# and it rebinds `esm2_alphabet`.
esm2_model, esm2_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm2_model.eval()
if torch.cuda.is_available():
    esm2_model.cuda()
batch_converter = esm2_alphabet.get_batch_converter()

# Define a function to generate embeddings for a sequence
def generate_esm_embeddings(sequence, repr_layer=33):
    """Per-token ESM-2 representations for one protein sequence.

    Args:
        sequence: amino-acid string.
        repr_layer: transformer layer whose output is returned. Defaults to 33, the last
            layer of `esm2_t33_650M_UR50D`.

    Returns:
        Tensor of shape (num_tokens, embed_dim) on the model's device. num_tokens
        includes the special tokens added by the batch converter: a 111-residue input
        tokenizes to 113 tokens.
    """
    _, _, batch_tokens = batch_converter([("", sequence)])
    batch_tokens = batch_tokens.to('cuda' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        results = esm2_model(batch_tokens, repr_layers=[repr_layer])
    token_representations = results["representations"][repr_layer]
    return token_representations.squeeze(0)


In [None]:
# Re-install Biopython on a fresh runtime; the imports below repeat the notebook's first import cell.
!pip install biopython
import numpy as np
import matplotlib.pyplot as plt
from Bio.PDB import PDBParser, is_aa
import os
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from Bio.PDB import PDBParser, NeighborSearch, PDBList
from Bio.PDB.Polypeptide import is_aa



In [None]:
# Function to process concatenated sequences
def process_concatenated_sequence(seq1, seq2):
    """Join two chain sequences end-to-end and tokenize them as a single ESM input.

    Prints the joined length, the joined sequence and the token tensor's shape. Returns the
    token tensor (batch of 1, including special tokens) on the GPU if available, else the CPU.
    """
    joined = seq1 + seq2
    print(f"Concatenated Sequence Length: {len(joined)}")
    print(joined)

    tokens = batch_converter([("", joined)])[2]
    print(tokens.shape)

    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return tokens.to(target_device)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from Bio.PDB import PDBParser, is_aa

def contacts_from_pdb_file(
    pdb_file_path,
    threshold_distance=10.0,
    plot_dir='/content/drive/MyDrive/flamingo-pep-gen/initial-test-12-28/',
):
    """Binary CA-CA contact map for the standard amino-acid residues of a PDB file.

    Residues come from the first model only, in file order. A residue is kept when
    Biopython's `is_aa(..., standard=True)` accepts it and it has a CA atom. Pairs whose CA
    atoms are closer than `threshold_distance` are 1, others 0. The diagonal is 1 because a
    residue's distance to itself is 0.

    Side effect: saves an image of the map to `<plot_dir>/<pdb_id>_ground_truth_plot.png`,
    where pdb_id is the file's basename up to the first '.'.

    Args:
        pdb_file_path: path to a .pdb file.
        threshold_distance: contact cut-off (strict `<`), in the file's coordinate units (Å).
        plot_dir: existing directory for the ground-truth plot.

    Returns:
        (N, N) int numpy array, N = number of kept residues.
    """
    print(f"Processing PDB file at {pdb_file_path}")

    # Load the PDB file using a PDBParser
    pdb_parser = PDBParser(QUIET=True)
    pdb_id = pdb_file_path.split('/')[-1].split('.')[0]  # Extract pdb_id from file path
    structure = pdb_parser.get_structure(pdb_id, pdb_file_path)

    # One CA coordinate per kept residue, stored in a list rather than keyed by
    # (chain, resseq). Such keys repeat for insertion codes (52 vs 52A) and across models,
    # which made the key list longer than the matrix. Only the first model is read.
    first_model = next(iter(structure), None)
    ca_coords = []
    if first_model is not None:
        for chain in first_model:
            for residue in chain:
                if is_aa(residue, standard=True) and "CA" in residue:
                    ca_coords.append(residue["CA"].get_coord())

    # Vectorised pairwise CA-CA distances (diagonal is 0).
    coords = np.asarray(ca_coords, dtype=float).reshape(-1, 3)
    dist_matrix = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)

    # Generate binary contact map from distance matrix
    contact_map = (dist_matrix < threshold_distance).astype(int)

    # Visualization of Binary Contact Map; always close the figure so loops don't leak memory.
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.imshow(contact_map, cmap="viridis", origin="lower")
        ax.set_title(f"Binary Contact Map for {pdb_id}")
        fig.savefig(os.path.join(plot_dir, f'{pdb_id}_ground_truth_plot.png'))
    finally:
        plt.close(fig)

    return contact_map


In [None]:
def compute_precisions(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    src_lengths: Optional[torch.Tensor] = None,
    minsep: int = 6,
    maxsep: Optional[int] = None,
    override_length: Optional[int] = None,  # for casp
):
    """Precision of the highest-scoring predicted contacts, measured at fractions of L.

    Pairs (i, j) with minsep <= j - i (and j - i < maxsep, if given) and a non-negative
    target are ranked by prediction score. Precision is computed over the top
    0.1L, 0.2L, ..., L of that ranking.

    NOTE: this notebook defines this function twice; the cell that ran last wins.

    Args:
        predictions: (L, L) or (B, L, L) scores; numpy arrays are converted to tensors.
        targets: same shape as `predictions`; negative entries mark invalid pairs.
        src_lengths: optional (B,) lengths; positions at or past a length are masked out.
        minsep: minimum sequence separation j - i.
        maxsep: optional exclusive maximum sequence separation.
        override_length: the L used for the cut-offs. When None, it is inferred as the
            number of non-negative targets in the first row of the first batch item.

    Returns:
        dict of (B,) tensors: "AUC" (mean over the ten cut-offs), "P@L",
        "P@L2" (top L/2) and "P@L5" (top L/5).

    Raises:
        ValueError: if `predictions` and `targets` have different sizes.
    """
    if isinstance(predictions, np.ndarray):
        predictions = torch.from_numpy(predictions)
    if isinstance(targets, np.ndarray):
        targets = torch.from_numpy(targets)
    if predictions.dim() == 2:
        predictions = predictions.unsqueeze(0)
    if targets.dim() == 2:
        targets = targets.unsqueeze(0)
    if override_length is None:
        # Infer L only when the caller did not supply one, so an explicit value is honoured.
        override_length = (targets[0, 0] >= 0).sum()

    # Check sizes
    if predictions.size() != targets.size():
        raise ValueError(
            f"Size mismatch. Received predictions of size {predictions.size()}, "
            f"targets of size {targets.size()}"
        )
    device = predictions.device

    batch_size, seqlen, _ = predictions.size()
    seqlen_range = torch.arange(seqlen, device=device)

    # sep[0, i, j] = j - i
    sep = seqlen_range.unsqueeze(0) - seqlen_range.unsqueeze(1)
    sep = sep.unsqueeze(0)
    valid_mask = sep >= minsep
    valid_mask = valid_mask & (targets >= 0)  # negative targets are invalid

    if maxsep is not None:
        valid_mask &= sep < maxsep

    if src_lengths is not None:
        valid = seqlen_range.unsqueeze(0) < src_lengths.unsqueeze(1)
        valid_mask &= valid.unsqueeze(1) & valid.unsqueeze(2)
    else:
        src_lengths = torch.full([batch_size], seqlen, device=device, dtype=torch.long)

    # Invalid pairs sink to the bottom of the ranking.
    predictions = predictions.masked_fill(~valid_mask, float("-inf"))

    x_ind, y_ind = np.triu_indices(seqlen, minsep)
    predictions_upper = predictions[:, x_ind, y_ind]
    targets_upper = targets[:, x_ind, y_ind]

    topk = seqlen if override_length is None else max(seqlen, override_length)
    indices = predictions_upper.argsort(dim=-1, descending=True)[:, :topk]
    topk_targets = targets_upper[torch.arange(batch_size).unsqueeze(1), indices]
    if topk_targets.size(1) < topk:
        # Fewer candidate pairs than cut-off positions: missing slots count as misses.
        topk_targets = F.pad(topk_targets, [0, topk - topk_targets.size(1)])

    cumulative_dist = topk_targets.type_as(predictions).cumsum(-1)

    gather_lengths = src_lengths.unsqueeze(1)
    if override_length is not None:
        gather_lengths = override_length * torch.ones_like(
            gather_lengths, device=device
        )

    # Zero-based positions of the 0.1L, 0.2L, ..., L cut-offs.
    gather_indices = (
        torch.arange(0.1, 1.1, 0.1, device=device).unsqueeze(0) * gather_lengths
    ).type(torch.long) - 1

    binned_cumulative_dist = cumulative_dist.gather(1, gather_indices)
    binned_precisions = binned_cumulative_dist / (gather_indices + 1).type_as(
        binned_cumulative_dist
    )

    pl5 = binned_precisions[:, 1]
    pl2 = binned_precisions[:, 4]
    pl = binned_precisions[:, 9]
    auc = binned_precisions.mean(-1)

    return {"AUC": auc, "P@L": pl, "P@L2": pl2, "P@L5": pl5}


def evaluate_prediction(
    predictions: torch.Tensor,
    targets: torch.Tensor,
) -> Dict[str, float]:
    """Contact-precision metrics for one prediction, bucketed by sequence separation.

    Buckets are local [3, 6), short [6, 12), medium [12, 24) and long [24, inf). Each
    key is "<bucket>_<metric>" for every metric returned by `compute_precisions`.

    NOTE: this notebook defines this function twice; the cell that ran last wins.

    Args:
        predictions: (L, L) contact scores, tensor or numpy array.
        targets: (L, L) ground-truth contacts, tensor or numpy array; negative values
            mark invalid pairs.

    Raises:
        ValueError: propagated from `compute_precisions` when the shapes differ.
    """
    # Both inputs may be numpy; `targets.to(predictions.device)` needs a tensor predictions.
    if isinstance(predictions, np.ndarray):
        predictions = torch.from_numpy(predictions)
    if isinstance(targets, np.ndarray):
        targets = torch.from_numpy(targets)
    contact_ranges = [
        ("local", 3, 6),
        ("short", 6, 12),
        ("medium", 12, 24),
        ("long", 24, None),
    ]
    metrics = {}
    targets = targets.to(predictions.device)
    for name, minsep, maxsep in contact_ranges:
        rangemetrics = compute_precisions(
            predictions,
            targets,
            minsep=minsep,
            maxsep=maxsep,
        )
        for key, val in rangemetrics.items():
            metrics[f"{name}_{key}"] = val.item()
    return metrics

In [None]:
"""Adapted from: https://github.com/rmrao/evo/blob/main/evo/visualize.py"""
def plot_contacts_and_predictions(
    predictions: Union[torch.Tensor, np.ndarray],
    contacts: Union[torch.Tensor, np.ndarray],
    ax: Optional[mpl.axes.Axes] = None,
    # artists: Optional[ContactAndPredictionArtists] = None,
    cmap: str = "Blues",
    ms: float = 1,
    title: Union[bool, str, Callable[[float], str]] = True,
    animated: bool = False,
) -> None:
    """Draw predicted contact scores and the top-L contact calls for one (L, L) map.

    The scores are drawn as an image with the lower triangle masked. The top-L scoring
    pairs with |i - j| >= 6 are overlaid as markers: blue for true positives, red for
    false positives, and grey for true contacts that are not in the top L.

    NOTE: this notebook defines this function twice; the cell that ran last wins.

    Args:
        predictions: (L, L) contact scores.
        contacts: (L, L) ground truth. For the markers, any nonzero entry counts as a
            contact. For the title metric, `contacts` is passed unchanged to
            `compute_precisions`, where negative entries mark invalid pairs.
        ax: axes to draw on; defaults to the current axes.
        cmap, animated: forwarded to `imshow`; ms: marker size for `plot`.
        title: a string is used verbatim; a callable receives the long-range
            (sep >= 24) P@L; True uses a default title; False draws no title.
    """
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.detach().cpu().numpy()
    if isinstance(contacts, torch.Tensor):
        contacts = contacts.detach().cpu().numpy()
    if ax is None:
        ax = plt.gca()

    seqlen = contacts.shape[0]
    # relative_distance[i, j] = j - i; the image hides the lower triangle (j < i).
    relative_distance = np.add.outer(-np.arange(seqlen), np.arange(seqlen))
    bottom_mask = relative_distance < 0
    masked_image = np.ma.masked_where(bottom_mask, predictions)
    # Pairs closer than 6 in sequence are excluded from the top-L selection.
    invalid_mask = np.abs(np.add.outer(np.arange(seqlen), -np.arange(seqlen))) < 6
    predictions = predictions.copy()
    predictions[invalid_mask] = float("-inf")

    topl_val = np.sort(predictions.reshape(-1))[-seqlen]
    pred_contacts = predictions >= topl_val

    # Use an explicit boolean mask: `~` on an integer array is bitwise NOT, and `&` fails
    # on float arrays. For 0/1/-1 integer maps this matches the bitwise logic.
    contact_mask = np.asarray(contacts) != 0
    true_positives = contact_mask & pred_contacts & ~bottom_mask
    false_positives = ~contact_mask & pred_contacts & ~bottom_mask
    other_contacts = contact_mask & ~pred_contacts & ~bottom_mask

    if isinstance(title, str):
        title_text: Optional[str] = title
    elif title:
        long_range_pl = compute_precisions(predictions, contacts, minsep=24)[
            "P@L"
        ].item()
        if callable(title):
            title_text = title(long_range_pl)
        else:
            title_text = f"Long Range P@L: {100 * long_range_pl:0.1f}"
    else:
        title_text = None

    ax.imshow(masked_image, cmap=cmap, animated=animated)
    # Markers are plotted at (x=i, y=j), i.e. transposed relative to the image, so they
    # fall in the triangle the image leaves blank.
    ax.plot(*np.where(other_contacts), "o", c="grey", ms=ms)
    ax.plot(*np.where(false_positives), "o", c="r", ms=ms)
    ax.plot(*np.where(true_positives), "o", c="b", ms=ms)
    if title_text is not None:
        ax.set_title(title_text)

    ax.axis("square")
    ax.set_xlim([0, seqlen])
    ax.set_ylim([0, seqlen])

In [None]:
# List the plots saved under images-1000set.
!ls /content/drive/MyDrive/flamingo-pep-gen/initial-test-12-28/images-1000set

In [None]:
import torch.nn.functional as F


In [None]:
# Predict contacts and evaluate
# For each complex, run ESM-2 on the two chain sequences joined end-to-end, then score the
# predicted map against the CA-CA contact map built from the complex's PDB file.
# NOTE(review): `esm2_model` is the single-sequence ESM-2 checkpoint, not the MSA
# Transformer, although the output folder is named "...-msatransform".
esm2_predictions = {}
esm2_results = []
mismatch_errors = []

pred_plot_dir = '/content/drive/MyDrive/flamingo-pep-gen/initial-test-12-28/images-test1-msatransform/'
# savefig does not create folders, and the cell that creates this one runs after this
# loop, so make sure it exists up front.
os.makedirs(pred_plot_dir, exist_ok=True)

for idx, row in df.iterrows():
    header = row['Header']
    pdb_file_path = f'/content/drive/MyDrive/flamingo-pep-gen/pinder-dataset/2023-11/pdbs/2023-11/pdbs/{header}.pdb'

    # Generate contact map from PDB file (also writes a ground-truth plot)
    contact_map = contacts_from_pdb_file(pdb_file_path)

    # Process sequences and predict contacts
    seq1, seq2 = row['Sequence1'], row['Sequence2']
    concatenated_tokens = process_concatenated_sequence(seq1, seq2)

    with torch.no_grad():
        esm2_predictions[header] = esm2_model.predict_contacts(concatenated_tokens)[0].cpu()

    try:
        # compute_precisions raises ValueError when the predicted map (presumably one row
        # per residue of seq1 + seq2) and the PDB map (one row per residue with a CA atom)
        # differ in size.
        metrics = {"id": header, "model": "ESM-2 (Unsupervised)"}
        metrics.update(evaluate_prediction(esm2_predictions[header], torch.tensor(contact_map)))
        esm2_results.append(metrics)

        # Plotting results; the figure is closed even if plotting fails.
        fig, ax = plt.subplots(figsize=(6, 6))
        try:
            prediction = esm2_predictions[header]
            plot_contacts_and_predictions(prediction, contact_map, ax=ax, title=lambda prec: f"{header}: Long Range P@L: {100 * prec:0.1f}")
            fig.savefig(os.path.join(pred_plot_dir, f'{header}_pred_plot.png'))
        finally:
            plt.close(fig)
    except ValueError as e:
        print(f"Size mismatch error for {header}: {e}")
        mismatch_errors.append(header)

# Convert results to a DataFrame
esm2_results_df = pd.DataFrame(esm2_results)
print(esm2_results_df)

# Print or save the headers with mismatch errors
print("Headers with size mismatch errors:", mismatch_errors)


Processing PDB file at /content/drive/MyDrive/flamingo-pep-gen/pinder-dataset/2023-11/pdbs/2023-11/pdbs/6j5a__B1_A0A287B4I0--6j5a__D1_Q95339.pdb
Concatenated Sequence Length: 111
PFDQMTIEDLNEVFPETKLDKKKYASVVPLKDRRLLEVKLGELPSWILMRDFTPSGIAGAFQRGYYRYYNKYVNVKKGSVAGLSMVLAAYVVFNYCRSYKELKHERLRKYH
torch.Size([1, 113])
Processing PDB file at /content/drive/MyDrive/flamingo-pep-gen/pinder-dataset/2023-11/pdbs/2023-11/pdbs/3u8q__A1_P24627--3u8q__B1_P24627.pdb
Concatenated Sequence Length: 341
YTRVVWCAVGPEEQKKCQQWSQQSGQNVTCATASTTDDCIVLVLKGEADALNLDGGYIYTAGKCGLVPVLAENRKSSKHSSLDCVLRPTEGYLAVAVVKKANEGLTWNSLKDKKSCHTAVDRTAGWNIPMGLIVNQTGSCAFDEFFSQSCAPGADPKSRLCALCAGDDQGLDKCVPNSKEKYYGYTGAFRCLAEDVGDVAFVKNDTVWENTNGESTADWAKNLKREDFRLLCLDGTRKPVTEAQSCHLAVAPNHAVVSRSDRAAHVEQVLLHQQALFGKNGKNCPDKFCLFKSETKNLLFNDNTECLAKLGGRPTYEEYLGTEYVTAIANLKKCSLEACAF
torch.Size([1, 343])
Processing PDB file at /content/drive/MyDrive/flamingo-pep-gen/pinder-dataset/2023-11/pdbs/2023-11/pdbs/1ezv__D2_P07143--1ezv__I2_P22289.pdb
Concatena

KeyboardInterrupt: ignored

In [None]:
import os
import shutil

# Relocate the ground-truth contact plots from the top of initial-test-12-28/
# into the image folder that also holds the prediction plots.
source_dir = '/content/drive/MyDrive/flamingo-pep-gen/initial-test-12-28/'
destination_dir = '/content/drive/MyDrive/flamingo-pep-gen/initial-test-12-28/images-test1-msatransform/'

# exist_ok makes this a no-op when the folder is already there.
os.makedirs(destination_dir, exist_ok=True)

# Only PNGs whose name mentions "ground" are moved; everything else stays put.
ground_truth_pngs = [
    entry for entry in os.listdir(source_dir)
    if entry.endswith('.png') and 'ground' in entry
]

for filename in ground_truth_pngs:
    source_file = os.path.join(source_dir, filename)
    destination_file = os.path.join(destination_dir, filename)
    shutil.move(source_file, destination_file)
    print(f"Moved {filename}")


Moved 6j5a__B1_A0A287B4I0--6j5a__D1_Q95339_ground_truth_plot.png
Moved 3u8q__A1_P24627--3u8q__B1_P24627_ground_truth_plot.png
Moved 1ezv__D2_P07143--1ezv__I2_P22289_ground_truth_plot.png
Moved 3hi4__A1_P22862--3hi4__B1_P22862_ground_truth_plot.png
Moved 2ooz__A1_P14174--2ooz__B1_P14174_ground_truth_plot.png
Moved 3ju9__A2_P55915--3ju9__A4_P55915_ground_truth_plot.png
Moved 6e9v__O1_UNDEFINED--6e9v__S1_UNDEFINED_ground_truth_plot.png
Moved 2c5c__C1_P69178--2c5c__D1_P69178_ground_truth_plot.png
Moved 5yb2__B1_P04578--5yb2__E1_UNDEFINED_ground_truth_plot.png
Moved 7rrp__K1_P02794--7rrp__X1_P02794_ground_truth_plot.png
Moved 3vst__C1_A2ICH1--3vst__D1_A2ICH1_ground_truth_plot.png
Moved 2a7a__A3_P81461--2a7a__A4_P81461_ground_truth_plot.png
Moved 7azf__A1_P0A988--7azf__E1_UNDEFINED_ground_truth_plot.png
Moved 5gvl__A1_A5K8L9--5gvl__B1_A5K8L9_ground_truth_plot.png
Moved 3k2b__A1_P25856--3k2b__B1_P25856_ground_truth_plot.png
Moved 6p4a__A1_A0A0E4B213--6p4a__C1_P00698_ground_truth_plot.png
Move

In [None]:
# Build the per-complex metrics table from the accumulated result dicts and
# show it as the cell output.
esm2_results_df = pd.DataFrame(data=esm2_results)
esm2_results_df

Unnamed: 0,id,model,local_AUC,local_P@L,local_P@L2,local_P@L5,short_AUC,short_P@L,short_P@L2,short_P@L5,medium_AUC,medium_P@L,medium_P@L2,medium_P@L5,long_AUC,long_P@L,long_P@L2,long_P@L5
0,6j5a__B1_A0A287B4I0--6j5a__D1_Q95339,ESM-2 (Unsupervised),0.876992,0.828829,0.872727,0.909091,0.115978,0.099099,0.145455,0.090909,0.065866,0.063063,0.090909,0.045455,0.009952,0.009009,0.018182,0.0
1,3u8q__A1_P24627--3u8q__B1_P24627,ESM-2 (Unsupervised),0.940748,0.900293,0.935294,0.970588,0.754179,0.609971,0.747059,0.823529,0.656982,0.519062,0.652941,0.808824,0.772352,0.630499,0.764706,0.911765
2,1ezv__D2_P07143--1ezv__I2_P22289,ESM-2 (Unsupervised),0.890004,0.856667,0.873333,0.916667,0.520314,0.36,0.493333,0.716667,0.382832,0.223333,0.353333,0.566667,0.447155,0.293333,0.413333,0.666667
3,3hi4__A1_P22862--3hi4__B1_P22862,ESM-2 (Unsupervised),0.950083,0.880074,0.97786,0.972222,0.646348,0.439114,0.649446,0.861111,0.530419,0.343173,0.520295,0.703704,0.946177,0.876384,0.96679,0.990741
4,2ooz__A1_P14174--2ooz__B1_P14174,ESM-2 (Unsupervised),0.917497,0.833333,0.912281,1.0,0.522851,0.372807,0.482456,0.711111,0.445964,0.302632,0.368421,0.577778,0.679537,0.583333,0.684211,0.8
5,3ju9__A2_P55915--3ju9__A4_P55915,ESM-2 (Unsupervised),0.66802,0.552743,0.64135,0.776596,0.653013,0.472574,0.658228,0.808511,0.689607,0.516878,0.691983,0.851064,0.260797,0.267932,0.291139,0.265957
6,6e9v__O1_UNDEFINED--6e9v__S1_UNDEFINED,ESM-2 (Unsupervised),0.963784,0.938389,0.971564,0.97619,0.567847,0.412322,0.50237,0.77381,0.846364,0.665877,0.867299,1.0,0.45481,0.438389,0.445498,0.47619
7,2c5c__C1_P69178--2c5c__D1_P69178,ESM-2 (Unsupervised),0.86397,0.702899,0.913043,0.962963,0.862374,0.702899,0.884058,1.0,0.875486,0.73913,0.869565,1.0,0.562659,0.485507,0.492754,0.703704
8,5yb2__B1_P04578--5yb2__E1_UNDEFINED,ESM-2 (Unsupervised),0.981964,0.966102,0.965517,1.0,0.271854,0.254237,0.275862,0.272727,0.0,0.0,0.0,0.0,0.016715,0.033898,0.0,0.0
9,7rrp__K1_P02794--7rrp__X1_P02794,ESM-2 (Unsupervised),0.929543,0.921512,0.936047,0.926471,0.411817,0.372093,0.366279,0.455882,0.464822,0.296512,0.418605,0.661765,0.73946,0.552326,0.773256,0.882353


In [None]:
# Persist the per-complex metrics table. The results folder is created first
# (no-op if present) so a fresh Drive layout does not make to_csv fail with a
# missing-directory error.
results_csv_path = '/content/drive/MyDrive/flamingo-pep-gen/initial-test-12-28/results_csv-test1-msatransform/msa-results.csv'
os.makedirs(os.path.dirname(results_csv_path), exist_ok=True)
esm2_results_df.to_csv(results_csv_path, index=False)

In [None]:
# Average each contact-range metric across all evaluated complexes.
# Result columns follow "<range>_<metric>" (e.g. "long_AUC", "long_P@L5"),
# as shown in the results table above.
CONTACT_RANGES = ["local", "short", "medium", "long"]
SUMMARY_METRICS = ["AUC", "P@L5"]

avg_metrics_df = pd.DataFrame(
    {
        metric: [
            esm2_results_df[f"{contact_range}_{metric}"].mean()
            for contact_range in CONTACT_RANGES
        ]
        for metric in SUMMARY_METRICS
    },
    index=CONTACT_RANGES,
)

# Scalar handles whose names match the columns they summarize
# (the *_auc names previously held P@L5 means).
avg_local_auc, avg_short_auc, avg_medium_auc, avg_long_auc = avg_metrics_df["AUC"]
avg_local_p_at_l5, avg_short_p_at_l5, avg_medium_p_at_l5, avg_long_p_at_l5 = avg_metrics_df["P@L5"]

avg_metrics_df

(0.8843629768020228, 0.62484718221975, 0.5568621333100294, 0.544931571734579)

In [None]:
# How many complexes were skipped for a predicted/ground-truth size mismatch.
n_mismatched = len(mismatch_errors)
n_mismatched

29

In [None]:
# Inspect the contact-prediction entry point; its repr also prints the full
# ESM-2 module tree (layers and contact head).
contact_predictor = esm2.predict_contacts
contact_predictor

<bound method ESM2.predict_contacts of ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_laye