In [1]:
%cd /fs/pool/pool-marsot/tankbind_philip/TankBind/tankbind/

/fs/gpfs41/lv11/fileset01/pool/pool-marsot/tankbind_philip/TankBind/tankbind


In [43]:
import os
import sys
import re
sys.path.append('/fs/pool/pool-marsot')
import pandas as pd
import torch
from esm import pretrained, FastaBatchedDataset
from Bio.PDB import PDBParser

import os
import pickle
from multiprocessing import Pool
import pickle
three_to_one = {'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F', 'GLY': 'G', 'HIS': 'H', 
                'ILE': 'I', 'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 
                'ARG': 'R', 'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'}

def extract_protein_names(folder_path):
    protein_names = []
    filepaths = []
    
    # Regex pattern to match the protein name
    pattern = r'(\w+)_protein\.pdb$'
    
    # Iterate through all files in the folder
    for filename in os.listdir(folder_path):
        if filename.endswith('_protein.pdb'):
            filepath = os.path.join(folder_path, filename)
            match = re.search(pattern, filename)
            if match:
                protein_name = match.group(1)
                protein_names.append(protein_name)
                filepaths.append(filepath)
    
    return protein_names, filepaths


def get_sequences_from_pdbfile(file_path):

    # Parse the PDB file to extract sequences
    biopython_parser = PDBParser()
    try:
        structure = biopython_parser.get_structure("random_id", file_path)
        structure = structure[0]
        sequence = None
        for i, chain in enumerate(structure):
            seq = ""
            for residue in chain:
                if residue.get_resname() == "HOH":
                    continue
                c_alpha, n, c, o = None, None, None, None
                for atom in residue:
                    if atom.name == "CA":
                        c_alpha = list(atom.get_vector())
                    if atom.name == "N":
                        n = list(atom.get_vector())
                    if atom.name == "C":
                        c = list(atom.get_vector())
                    if atom.name == "O":
                        o = list(atom.get_vector())
                if c_alpha is not None and n is not None and c is not None and o is not None:  # only append residue if it is an amino acid AND has an oxygen, like for TankBind https://github.com/luwei0917/TankBind/blob/ff85f511db11d7a3e648d2e01cd6fdb4f9823483/tankbind/feature_utils.py#L204
                    # Note that this condition can be false for the last residue of a chain, where the oxygen is not given in the PDB file although it should be there
                    try:
                        seq += three_to_one[residue.get_resname()]
                    except Exception:
                        #seq += "-"
                        print("encountered unknown AA: ", residue.get_resname(), " in the complex. Replacing it with a dash - .")
            
            if sequence is None:
                sequence = seq
            else:
                sequence += ":" + seq
    except:
        raise Exception(f"Error parsing file {file_path}")
    return sequence
def print_cuda_memory_usage():
    print(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"Cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
    print(f"Free: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved()) / 1024**2:.2f} MB")
    print(f"Total: {torch.cuda.get_device_properties(0).total_memory / 1024**2:.2f} MB")

def create_ESM_embeddings(labels, sequences, model_dim="650m"):
    """
    Parameters
    ----------
    labels : list
        List of labels.
    sequences : list
        List of sequences.
    Returns
    -------
    lm_embedding : dict[str: torch.Tensor]
        List of ESM embeddings, indexed by label.
    """
    if model_dim == "650m":
        model_location = "esm2_t33_650M_UR50D"
        toks_per_batch = 8192
    elif model_dim == "15B":
        model_location = "esm2_t48_15B_UR50D"
        toks_per_batch = 4096
    model, alphabet = pretrained.load_model_and_alphabet(model_location)
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    repr_layers = [33]
    truncation_seq_length = 1000000000
    global dataset
    global sequences_2
    global labels_2
    labels_2 = labels
    sequences_2 = sequences

    dataset = FastaBatchedDataset(labels, sequences)
    
    batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(truncation_seq_length),
        batch_sampler=batches,
    )

    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in repr_layers)
    repr_layers = [
        (i + model.num_layers + 1) % (model.num_layers + 1) for i in repr_layers
    ]
    embeddings = {}

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=False)
            print_cuda_memory_usage()
            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }

            for i, label in enumerate(labels):
                truncate_len = min(truncation_seq_length, len(strs[i]))
                embeddings[label] = representations[33][i, 1 : truncate_len + 1].clone()
    return embeddings

def get_all_ESM_embeddings(protein_paths, protein_names, model="650m", save_dir="/fs/pool/pool-marsot/tankbind_philip/TankBind/data/protein_remove_extra_chains_10A"):

    if not os.path.exists(f"{save_dir}/sequences_dict.pkl"):
        sequences_dict = {protein_names[i]: get_sequences_from_pdbfile(protein_paths[i]) for i in range(len(protein_paths))}
        labels_cleaned, sequences_cleaned = [], []
        for name, sequence in sequences_dict.items():
            s = sequence.split(':')
            sequences_cleaned.extend(s)
            labels_cleaned.extend([name + '_chain_' + str(j) for j in range(len(s))])
        with open(f"{save_dir}/sequences_dict.pkl", "wb") as file:
            pickle.dump(sequences_cleaned, file)
        with open(f"{save_dir}/labels_cleaned.pkl", "wb") as file:
            pickle.dump(labels_cleaned, file)
        with open(f"{save_dir}/protein_names.pkl", "wb") as file:
            pickle.dump(protein_names, file)

    else: 
        with open(f"{save_dir}/sequences_dict.pkl", "rb") as file:
            sequences_cleaned = pickle.load(file)
        with open(f"{save_dir}/labels_cleaned.pkl", "rb") as file:
            labels_cleaned = pickle.load(file)
        with open(f"{save_dir}/protein_names.pkl", "rb") as file:
            protein_names = pickle.load(file)
        with open(f"{save_dir}/sequences_dict.pkl", "rb") as file:
            sequences_dict = pickle.load(file)

    if os.path.exists(f"{save_dir}/esm_embeddings_{model}_intermediate.pkl") and False:
        with open(f"{save_dir}/esm_embeddings_{model}_intermediate.pkl", "rb") as file:
            results = pickle.load(file)
    else:
        results = create_ESM_embeddings(labels_cleaned, sequences_cleaned, model)
        with open(f"{save_dir}/esm_embeddings_{model}_intermediate.pkl", "wb") as file:
            pickle.dump(results, file)
    return results


def get_pdbbind_ESM_embeddings(folder_path="/fs/pool/pool-marsot/tankbind_philip/TankBind/data/protein_remove_extra_chains_10A", model="650m"):
    import time
    start = time.time()
    protein_names, protein_paths = extract_protein_names(folder_path)
    result = get_all_ESM_embeddings(protein_paths, protein_names, model=model)
    end = time.time()
    print(f"Time taken: {end - start}")

    with open(f"{folder_path}/esm_embeddings_{model}.pkl", "wb") as file:
        pickle.dump(result, file)

    return result

if __name__ == "__main__":
    embeddings = get_pdbbind_ESM_embeddings()

Processing 1 of 888 batches (178 sequences)
Allocated: 2624.49 MB
Cached: 7338.00 MB
Free: 12758.00 MB
Total: 20096.00 MB
Processing 2 of 888 batches (143 sequences)
Allocated: 2624.30 MB
Cached: 7338.00 MB
Free: 12758.00 MB
Total: 20096.00 MB
Processing 3 of 888 batches (126 sequences)
Allocated: 2624.53 MB
Cached: 7338.00 MB
Free: 12758.00 MB
Total: 20096.00 MB
Processing 4 of 888 batches (109 sequences)
Allocated: 2624.54 MB
Cached: 7338.00 MB
Free: 12758.00 MB
Total: 20096.00 MB
Processing 5 of 888 batches (98 sequences)
Allocated: 2624.41 MB
Cached: 7338.00 MB
Free: 12758.00 MB
Total: 20096.00 MB
Processing 6 of 888 batches (94 sequences)
Allocated: 2624.67 MB
Cached: 7338.00 MB
Free: 12758.00 MB
Total: 20096.00 MB
Processing 7 of 888 batches (91 sequences)
Allocated: 2624.78 MB
Cached: 7338.00 MB
Free: 12758.00 MB
Total: 20096.00 MB
Processing 8 of 888 batches (88 sequences)
Allocated: 2624.77 MB
Cached: 7338.00 MB
Free: 12758.00 MB
Total: 20096.00 MB
Processing 9 of 888 batches 

In [44]:
1

1

In [45]:
# Open the intermediate esm embeddings
with open("/fs/pool/pool-marsot/tankbind_philip/TankBind/data/protein_remove_extra_chains_10A/esm_embeddings_650m_intermediate.pkl", "rb") as file:
    results = pickle.load(file)

In [16]:
# Open the protein names
with open("/fs/pool/pool-marsot/tankbind_philip/TankBind/data/protein_remove_extra_chains_10A/protein_names.pkl", "rb") as file:
    protein_names = pickle.load(file)

In [2]:
with open("/fs/pool/pool-marsot/tankbind_philip/TankBind/data/protein_remove_extra_chains_10A/sequences_dict.pkl", "rb") as file:
    sequences_cleaned = pickle.load(file)
with open("/fs/pool/pool-marsot/tankbind_philip/TankBind/data/protein_remove_extra_chains_10A/labels_cleaned.pkl", "rb") as file:
    labels_cleaned = pickle.load(file)

In [46]:
from collections import defaultdict
# Create a dictionary to store the full sequences
full_sequences = defaultdict(str)
def default_tensor():
    return torch.zeros(0, 1280, dtype=torch.float32)
full_embeddings = defaultdict(default_tensor)
# Iterate through the labels and sequences
for label, sequence in zip(labels_cleaned, sequences_cleaned):
    # Extract the protein name (assuming it's the part before the chain identifier)
    protein_name = label.split('_')[0]  # Adjust this split if your naming convention is different
    
    # Concatenate the sequence to the existing sequence for this protein
    full_sequences[protein_name] += sequence
    full_embeddings[protein_name] = torch.cat([full_embeddings[protein_name], results[label]], dim=0)
# Convert defaultdict to regular dict
full_sequences = dict(full_sequences)

# Print some statistics
print(f"Total number of proteins: {len(full_sequences)}")
print("Sample of full sequences:")
for protein, sequence in list(full_sequences.items())[:5]:  # Print first 5 as a sample
    print(f"{protein}: {sequence[:50]}...")  # Print first 50 characters of each sequence

# Optionally, save the full sequences dictionary
with open("/fs/pool/pool-marsot/tankbind_philip/TankBind/data/protein_remove_extra_chains_10A/full_sequences.pkl", "wb") as file:
    pickle.dump(full_sequences, file)
with open("/fs/pool/pool-marsot/tankbind_philip/TankBind/data/protein_remove_extra_chains_10A/full_embeddings.pkl", "wb") as file:
    pickle.dump(full_embeddings, file)

Total number of proteins: 19421
Sample of full sequences:
3exo: GSFVEMVDNLRGKSGQGYYVEMTVGSPPQTLNILVDTGSSNFAVGAAPHP...
3v4v: YNVDTESALLYQGPHNTLFGYSVVLHSHGANRWLLVGAPTANWLANASVI...
3tjc: QFEERHLKFLQQLGKGNFGSVEMCRYDPLQDNTGEVVAVKKLQHSTEEHL...
1duv: SGFYHKHFLKLLDFTPAELNSLLQLAAKLKADKKSGKEEAKLTGKNIALI...
2xii: EIPLKYGATNEGKRQDPAMQKFRDNRLGAFIHWGLYAIPGGEWNGKVYGG...


In [33]:
for label in labels_cleaned:
    if label.startswith("1zyr"):
        print(label)

1zyr_chain_0
1zyr_chain_1


In [36]:
len(sequences_cleaned[labels_cleaned.index("1zyr_chain_0")]), len(sequences_cleaned[labels_cleaned.index("1zyr_chain_1")])

(1119, 1392)

In [42]:
results["1zyr_chain_0"].shape, results["1zyr_chain_1"].shape

(torch.Size([1022, 1280]), torch.Size([1022, 1280]))

In [40]:
for c in sequences_cleaned[labels_cleaned.index("1zyr_chain_1")]:
    if c == "-":
        print("found dash")

In [26]:
len(full_sequences["1zyr"]), len(full_embeddings["1zyr"])

(2511, 2044)

In [47]:
import torch
import pickle

# Load the full sequences dictionary (if not already in memory)
with open("/fs/pool/pool-marsot/tankbind_philip/TankBind/data/protein_remove_extra_chains_10A/full_sequences.pkl", "rb") as file:
    full_sequences = pickle.load(file)


with open("/fs/pool/pool-marsot/tankbind_philip/TankBind/data/protein_remove_extra_chains_10A/full_embeddings.pkl", "rb") as file:
    full_embeddings = pickle.load(file)

# New dictionary to store the cleaned embeddings
cleaned_embeddings = {}

for protein, embedding in full_embeddings.items():
    if protein not in full_sequences:
        print(f"Warning: Protein {protein} not found in full_sequences. Skipping.")
        continue

    sequence = full_sequences[protein]
    
    # Create a boolean mask: True for known residues, False for unknown
    mask = torch.tensor([res != '-' for res in sequence])
    try:
        # Apply the mask to the embedding tensor
        cleaned_embedding = embedding[mask]
    except:
        raise Exception(f"Error cleaning embedding for protein {protein}")
    # Store the cleaned embedding
    cleaned_embeddings[protein] = cleaned_embedding

    # Print some statistics
    print(f"Protein: {protein}")
    print(f"Original embedding shape: {embedding.shape}")
    print(f"Cleaned embedding shape: {cleaned_embedding.shape}")
    print(f"Number of unknown residues removed: {embedding.shape[0] - cleaned_embedding.shape[0]}")
    print("---")

# Optionally, save the cleaned embeddings dictionary
with open("/fs/pool/pool-marsot/tankbind_philip/TankBind/data/protein_remove_extra_chains_10A/cleaned_embeddings.pkl", "wb") as file:
    pickle.dump(cleaned_embeddings, file)

print("Cleaning complete. Cleaned embeddings saved to 'cleaned_embeddings.pkl'.")

Protein: 3exo
Original embedding shape: torch.Size([373, 1280])
Cleaned embedding shape: torch.Size([373, 1280])
Number of unknown residues removed: 0
---
Protein: 3v4v
Original embedding shape: torch.Size([955, 1280])
Cleaned embedding shape: torch.Size([955, 1280])
Number of unknown residues removed: 0
---
Protein: 3tjc
Original embedding shape: torch.Size([284, 1280])
Cleaned embedding shape: torch.Size([284, 1280])
Number of unknown residues removed: 0
---
Protein: 1duv
Original embedding shape: torch.Size([666, 1280])
Cleaned embedding shape: torch.Size([666, 1280])
Number of unknown residues removed: 0
---
Protein: 2xii
Original embedding shape: torch.Size([438, 1280])
Cleaned embedding shape: torch.Size([438, 1280])
Number of unknown residues removed: 0
---
Protein: 6ce2
Original embedding shape: torch.Size([242, 1280])
Cleaned embedding shape: torch.Size([242, 1280])
Number of unknown residues removed: 0
---
Protein: 2uup
Original embedding shape: torch.Size([438, 1280])
Cleane

In [8]:
with open("/fs/pool/pool-marsot/pdbbind/pdbbind2020/dataset/processed/protein.pt", "rb") as file:
    protein_dict = torch.load(file)

In [48]:
for key in protein_dict.keys():
    shape = protein_dict[key][0].shape[0]
    if not shape == cleaned_embeddings[key].shape[0]:
        print(key, shape, cleaned_embeddings[key].shape[0])

In [1]:
import pickle
with open("/fs/pool/pool-marsot/tankbind_philip/TankBind/data/protein_remove_extra_chains_10A/cleaned_embeddings.pkl", "rb") as file:
    cleaned_embeddings = pickle.load(file)

In [2]:
cleaned_embeddings["4bh4"]

tensor([[ 0.1009,  0.1309, -0.2387,  ..., -0.0376,  0.0712, -0.0861],
        [ 0.1740,  0.1842,  0.0625,  ..., -0.1398, -0.0275,  0.1693],
        [ 0.0706,  0.1059, -0.0583,  ..., -0.1095,  0.1373,  0.0745],
        ...,
        [ 0.0313,  0.0616,  0.0272,  ..., -0.0495, -0.1337, -0.1360],
        [ 0.1016,  0.0317,  0.0656,  ..., -0.0969, -0.1341, -0.0957],
        [-0.0016,  0.0575,  0.1318,  ..., -0.1403,  0.0348, -0.0027]])

In [2]:
with open("/fs/pool/pool-marsot/tankbind_philip/TankBind/packages/EquiBind/data/timesplit_test", "r") as file:
    timesplit_test = file.read().splitlines()
with open("/fs/pool/pool-marsot/tankbind_philip/TankBind/packages/EquiBind/data/timesplit_no_lig_overlap_val", "r") as file:
    timesplit_val = file.read().splitlines()

In [37]:
"3fqa" in timesplit_val

False

In [38]:
protein_names = list(cleaned_embeddings.keys())
timesplit_train = list(set(protein_names) - set(timesplit_test) - set(timesplit_val))

In [39]:
"4bh4" in timesplit_test, "4bh4" in timesplit_train, "4bh4" in timesplit_val

(False, False, True)

In [40]:
"3fqa" in timesplit_test, "3fqa" in timesplit_train, "3fqa" in timesplit_val

(False, True, False)

In [53]:
"1zsb" in timesplit_test, "1zsb" in timesplit_train, "1zsb" in timesplit_val

(False, False, True)

In [56]:
import numpy as np
train_index = d.query("group =='train'").index.values
train_after_warmup = new_dataset[train_index]

In [63]:
train_index.dtype

dtype('int64')

In [72]:
"1zsb" in d.loc[train_index]["protein_name"].unique()

False

In [1]:
train_after_warmup

NameError: name 'train_after_warmup' is not defined

In [3]:
esm_embeddings_train = cleaned_embeddings
esm_embeddings_val = {key: cleaned_embeddings[key] for key in timesplit_val}
esm_embeddings_test = {key: cleaned_embeddings[key] for key in timesplit_test}

In [44]:
esm_embeddings_val.keys().__len__()

968

In [14]:
%cd /fs/pool/pool-marsot/

/fs/gpfs41/lv11/fileset01/pool/pool-marsot


In [16]:
import sys; sys.path.append("/fs/pool/pool-marsot/tankbind_philip/TankBind/tankbind/")
from tankbind_philip.TankBind.tankbind.data import TankBindDataSet
add_noise_to_com=None
use_esm_embeddings=False
new_dataset = TankBindDataSet("/fs/pool/pool-marsot/pdbbind/pdbbind2020/dataset", add_noise_to_com=add_noise_to_com, use_esm_embeddings=use_esm_embeddings)

['/fs/pool/pool-marsot/pdbbind/pdbbind2020/dataset/processed/data.pt', '/fs/pool/pool-marsot/pdbbind/pdbbind2020/dataset/processed/protein.pt', '/fs/pool/pool-marsot/pdbbind/pdbbind2020/dataset/processed/compound.pt', '/fs/pool/pool-marsot/pdbbind/pdbbind2020/dataset/processed/esm_embeddings.pt']


In [17]:
d = new_dataset.data

In [18]:
train_index = d.query("group =='train'").index.values

In [None]:
d[d["protein_name"] == "1zsb"]

In [30]:
"1zsb" in d["protein_name"].unique()

True

In [19]:
train_after_warm_up = new_dataset[train_index]

In [27]:
from tqdm.notebook import tqdm

num_train = 0
num_test = 0
num_val = 0

with tqdm(train_after_warm_up, total=len(train_after_warm_up)) as pbar:
    for item in pbar:
        if item.group == "train":
            num_train += 1
            if item.pdb in timesplit_val or item.pdb in timesplit_test:
                raise Exception(f"Protein {item.pdb} is in both train and val/test sets")
        elif item.group == "test":
            num_test += 1
            if item.pdb not in timesplit_test:
                raise Exception(f"Protein {item.pdb} is in test set but not in timesplit_test")
        elif item.group == "val":
            num_val += 1
            if item.pdb not in timesplit_val:
                raise Exception(f"Protein {item.pdb} is in val set but not in timesplit_val")
        
        # Update the progress bar description with current values
        pbar.set_description(f"Train: {num_train}, Test: {num_test}, Val: {num_val}")

  0%|          | 0/153325 [00:00<?, ?it/s]

  fract_of_native_contact = (data.y.numpy() > 0).sum() / float(line['native_num_contact'])


KeyboardInterrupt: 

In [None]:
"3fqa" in d[d["group"]=="train"]["protein_name"].unique()

In [49]:
train_names = d.query("group == 'train'")["protein_name"].unique()

In [None]:
"1zsb" in d[d["group"]=="train"]["protein_name"].unique()

In [4]:
import torch
with open("/fs/pool/pool-marsot/pdbbind/pdbbind2020/dataset/processed/esm_embeddings.pt", "wb") as file:
    torch.save(cleaned_embeddings, file)
with open("/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/test_dataset/processed/esm_embeddings.pt", "wb") as file:
    torch.save(esm_embeddings_test, file)
with open("/fs/pool/pool-marsot/tankbind_philip/TankBind/dataset/val_dataset/processed/esm_embeddings.pt", "wb") as file:
    torch.save(esm_embeddings_val, file)