### get data

In [1]:
import pandas as pd
import sqlite3
import json
import os
import glob
import re
from tqdm.notebook import tqdm
import requests

In [8]:
url = "https://rest.uniprot.org/proteomes/search"
params = {
    "query": "organism_name:Flavobacterium AND reference:true",
    "format": "json",
    "size": 500
}
headers = {
    "Accept": "application/json"
}

response = requests.get(url, params=params, headers=headers, timeout=60)
response.raise_for_status()  # raises error for non-200 responses

data = response.json()
n = len(data.get("results", []))
print(f"Number of items: {n}")

# Extract proteome IDs (UPIDs)
upids = [item["id"] for item in data.get("results", []) if "id" in item]
print("Number of UPIDs:", len(upids))
print("First few:", upids[:5])


with open("./data/flavobacterium_reference_proteomes.json", "w", encoding="utf-8") as f:
    f.write(response.text)

print("Saved: flavobacterium_reference_proteomes.json")

Number of items: 279
Number of UPIDs: 279
First few: ['UP000037755', 'UP000175968', 'UP000178198', 'UP000184036', 'UP000184121']
Saved: flavobacterium_reference_proteomes.json


In [9]:
import requests
import gzip
from pathlib import Path
import time

OUT = Path("./data/proteomes_json")
OUT.mkdir(parents=True, exist_ok=True)

for upid in upids:
    print("Downloading", upid)

    url = "https://rest.uniprot.org/uniprotkb/stream"
    params = {
        "query": f"proteome:{upid}",
        "format": "json"
    }

    out_file = OUT / f"{upid}.json.gz"

    with requests.get(url, params=params, stream=True, timeout=120) as r:
        r.raise_for_status()
        with gzip.open(out_file, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    f.write(chunk)

    print("Saved", out_file)

    # polite delay to avoid rate limiting
    time.sleep(0.2)


Downloading UP000037755
Saved data\proteomes_json\UP000037755.json.gz
Downloading UP000175968
Saved data\proteomes_json\UP000175968.json.gz
Downloading UP000178198
Saved data\proteomes_json\UP000178198.json.gz
Downloading UP000184036
Saved data\proteomes_json\UP000184036.json.gz
Downloading UP000184121
Saved data\proteomes_json\UP000184121.json.gz
Downloading UP000184216
Saved data\proteomes_json\UP000184216.json.gz
Downloading UP000184232
Saved data\proteomes_json\UP000184232.json.gz
Downloading UP000184488
Saved data\proteomes_json\UP000184488.json.gz
Downloading UP000184516
Saved data\proteomes_json\UP000184516.json.gz
Downloading UP000184611
Saved data\proteomes_json\UP000184611.json.gz
Downloading UP000198302
Saved data\proteomes_json\UP000198302.json.gz
Downloading UP000198319
Saved data\proteomes_json\UP000198319.json.gz
Downloading UP000198345
Saved data\proteomes_json\UP000198345.json.gz
Downloading UP000198648
Saved data\proteomes_json\UP000198648.json.gz
Downloading UP000199

In [14]:
import gzip
import json

with gzip.open("./data/proteomes_json/UP000037755.json.gz", "rt", encoding="utf-8") as f:
    data = json.load(f)

print(type(data))
print(len(data["results"]))   # if it's a UniProt stream file

data["results"][1]["annotationScore"]


<class 'dict'>
3833


4.0

### data - annotation filtering - run once

In [15]:
import os
import glob
import json
import gzip
import re
from tqdm import tqdm

# Compiled regular expression for performance
x_re = re.compile(r"X", re.IGNORECASE)

def filter_json_gz_files(input_folder, output_file):
    """
    Reads all .json.gz files in 'input_folder', filters their "results" based on:
      1) annotationScore > 2  (i.e., annotation score >= 3)
      2) no "X" in result["sequence"]["value"] (case-insensitive)

    Note on "X": UniProt uses "X" to denote unknown/ambiguous residues. Keeping fully
    specified sequences avoids artifacts in downstream physicochemical feature extraction
    and transformer embeddings.

    Combines all filtered items into one single JSON list and writes it to 'output_file'.
    """

    combined_filtered = []

    # Gather all .json.gz files in the specified folder
    json_gz_files = glob.glob(os.path.join(input_folder, "*.json.gz"))

    for file_path in tqdm(json_gz_files, desc="Processing JSON.GZ files"):
        # Open gzipped JSON
        with gzip.open(file_path, "rt", encoding="utf-8") as file:
            data = json.load(file)

        # Only process if there's a top-level "results" key
        results = data.get("results", [])
        if not isinstance(results, list):
            continue

        filtered_results = []
        for result in results:
            # 1) annotationScore > 2  (handle float / missing)
            try:
                score = float(result.get("annotationScore", 0))
            except (TypeError, ValueError):
                score = 0.0
            if score <= 2:
                continue

            # 2) no X in sequence
            seq = result.get("sequence", {}).get("value", "") or ""
            if x_re.search(seq):
                continue

            filtered_results.append(result)

        combined_filtered.extend(filtered_results)

    # Write the final combined list (no "results" key) to the output JSON
    with open(output_file, "w", encoding="utf-8") as out_file:
        json.dump(combined_filtered, out_file, indent=2)

    print(f"Wrote {len(combined_filtered)} filtered entries to: {output_file}")


if __name__ == "__main__":
    input_folder = "./data/proteomes_json"  # folder containing UP000....json.gz files
    output_file = "./data/combined_filtered.json"
    filter_json_gz_files(input_folder, output_file)


Processing JSON.GZ files: 100%|██████████| 279/279 [03:03<00:00,  1.52it/s]


Wrote 37469 filtered entries to: ./data/combined_filtered.json


In [10]:
# Compiled regular expression for performance
x_re = re.compile(r"X", re.IGNORECASE)

def filter_json_files(input_folder, output_file):
    """
    Reads all .json files in 'input_folder', filters their "results"
    based on:
      1) annotationScore > 2
      2) no "X" in result["sequence"]["value"] (case-insensitive)

    Combines all filtered items into one single JSON list and writes it to 'output_file'.
    """

    combined_filtered = []
    
    # Gather all .json files in the specified folder
    json_files = glob.glob(os.path.join(input_folder, "*.json"))
    
    # Use tqdm to show a progress bar in Jupyter Notebook
    for json_file_path in tqdm(json_files, desc="Processing JSON files"):
        with open(json_file_path, 'r') as file:
            data = json.load(file)
            
            # Only process if there's a top-level "results" key
            if "results" in data:
                # Filter the results according to the given conditions
                filtered_results = [
                    result
                    for result in data["results"]
                    if int(result.get("annotationScore", 0)) > 2
                    and not x_re.search(result.get("sequence", {}).get("value", ""))
                ]
                
                # Accumulate into the combined list
                combined_filtered.extend(filtered_results)
    
    # Write the final combined list (no "results" key) to the output JSON
    with open(output_file, 'w') as out_file:
        json.dump(combined_filtered, out_file, indent=4)
    
    print(f"Filtered results have been written to: {output_file}")

if __name__ == "__main__":
    # Example usage:
    input_folder = "./data"    # Replace with the folder containing your JSON files
    output_file = "./data/combined_filtered.json"  # Output JSON file
    filter_json_files(input_folder, output_file)


Processing JSON files:   0%|          | 0/39 [00:00<?, ?it/s]

Filtered results have been written to: ./data/combined_filtered.json


### data - go term filtering - run once

In [16]:
# read the json file
with open('./data/combined_filtered.json') as f:
    data = json.load(f)

len(data)

37469

In [28]:
data[0]["uniProtKBCrossReferences"][9]

{'database': 'GO',
 'id': 'GO:0005886',
 'properties': [{'key': 'GoTerm', 'value': 'C:plasma membrane'},
  {'key': 'GoEvidenceType', 'value': 'IEA:UniProtKB-SubCell'}]}

In [29]:
from collections import Counter

print(len(data))

go_ids = []

for item in data:
    cross_refs = item.get("uniProtKBCrossReferences", [])
    for ref in cross_refs:
        if ref.get("database") == "GO":
            go_ids.append(ref.get("id"))  # e.g. "GO:0036125"

# Create a Counter to get the frequency of each GO term
go_counter = Counter(go_ids)

# Total GO terms (including duplicates)
total_go_count = sum(go_counter.values())

# Unique GO terms
unique_go_count = len(go_counter)

print("Total GO terms (including duplicates):", total_go_count)
print("Unique GO terms:", unique_go_count)
print("\nCounts per GO term:")

# Sort go_counter by frequency (descending)
go_counter = dict(sorted(go_counter.items(), key=lambda item: item[1], reverse=True))

# Keep GO terms with frequency >= 4 (rare-term filtering for training stability)
GO_MIN_FREQ = 4

target_go_terms = [
    go_term
    for go_term, count in go_counter.items()
    if count >= GO_MIN_FREQ
]
print(f"Number of target GO terms (freq >= {GO_MIN_FREQ}):", len(target_go_terms))

# Filter items to those that have >=1 GO term in target_go_terms
filtered_data = [
    item
    for item in data
    if any(
        ref.get("id") in target_go_terms
        for ref in item.get("uniProtKBCrossReferences", [])
        if ref.get("database") == "GO"
    )
]
print(f"Number of items after GO filtering (freq >= {GO_MIN_FREQ}):", len(filtered_data))

# ESM-2 model constraint (used later), printed here for traceability
ESM_MAX_LEN = 1022
filtered_data_len_ok = [
    item
    for item in filtered_data
    if len((item.get("sequence") or {}).get("value") or "") <= ESM_MAX_LEN
]
print(f"Number of items after length <= {ESM_MAX_LEN}:", len(filtered_data_len_ok))


37469
Total GO terms (including duplicates): 195632
Unique GO terms: 967

Counts per GO term:
Number of target GO terms: 967
Number of items after filtering: 36552


In [30]:
filtered_data[0]["sequence"]["value"]

'MAEKSKKPTPNKPKFSAWWIYTAIILVFLALNFFSSGSDFGNPKEILENQLITYAKDGDVKRIDISKSKQVAYIYLTEEASQKEKHSKPQDNGMSLLSSSGGADYEFRYGTLENFENTLNDIQENQGIEIPRKYLKADNTFFGDMLLTLLPIAVIIGIWIFIMRRMSSGGAGGAGGQIFNIGKSKAKLFDEKTDVKTSFKDVAGLEGAKEEVQEIVDFLKNPDKYTNLGGKIPKGALLVGPPGTGKTLLAKAVAGEAKVPFFSLSGSDFVEMFVGVGASRVRDLFKQAKEKSPAIIFIDEIDAIGRARGKANFSGSNDERENTLNQLLTEMDGFGTNTNVIVLAATNRADVLDTALMRAGRFDRQIYVDLPDVRERKEIFEVHLRPLKKVAEELDTEFMAKQTPGFSGADIANVCNEAALIAARQGKAAVGRQDFLDAVDRIVGGLEKKNKIITPEEKRAIAFHEAGHATVSWMLEHAAPLVKVTIVPRGRSLGAAWYLPEERLIVRPEQMLDEMCAAMGGRAAEKVTFNKISTGALSDLEKVTKQARMMVTTYGLNDEIGNLTYYDSSGQNEYNFSKPYSERTAELIDKEISKIIEAQYQRAIKILEDNKDKLNELAEVLLDKEVIFKDNLEKIFGKRPFDKGEEVEVNDTSVTES'

In [31]:
import pandas as pd

# Suppose filtered_data is a list of dicts, each dict containing:
#   - "sequence": {"value": ...}
#   - "uniProtKBCrossReferences": [{...}, {...}, ...] (including GO entries)
# And target_go_terms is your list of retained GO IDs (after frequency filtering).

# Create column names: first "sequence", then each of the GO terms
columns = ["sequence"] + target_go_terms

# Prepare an empty list to gather row data
rows = []

for entry in filtered_data:
    # 1. Get the sequence
    seq = entry["sequence"]["value"]

    # 2. Collect all GO IDs for this entry
    cross_refs = entry.get("uniProtKBCrossReferences", [])
    go_ids = {
        ref["id"]
        for ref in cross_refs
        if ref.get("database") == "GO"
    }

    # 3. Build the row:
    #    - first element is the sequence
    #    - then for each GO term in target_go_terms, check membership in go_ids
    row = [seq] + [1 if go_term in go_ids else 0 for go_term in target_go_terms]
    rows.append(row)

# Finally, create the DataFrame
df = pd.DataFrame(rows, columns=columns)

df


Unnamed: 0,sequence,GO:0005737,GO:0005524,GO:0005829,GO:0046872,GO:0000287,GO:0005886,GO:0008270,GO:0071555,GO:0009252,...,GO:0003977,GO:0003678,GO:0004452,GO:0004731,GO:0008782,GO:0008930,GO:0051745,GO:0036381,GO:0043715,GO:0051213
0,MAEKSKKPTPNKPKFSAWWIYTAIILVFLALNFFSSGSDFGNPKEI...,0,1,0,0,0,1,1,0,0,...,0,0,0,0,0,0,0,0,0,0
1,MPEIEKTTYQDTLAYAKQADAADPLARFRDMFNIPKDAQGNNLIYL...,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,MADTKYIFVTGGVSSSLGKGIIAASLAKLLQARGYSVTIQKLDPYI...,0,1,1,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,MKNTQKDIRALSREELRDFFVSQGEKAFRGNQVYEWLWVKGAHSFD...,1,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,MILLKDILYKVTLESVTGNPNVPVNAIHFDSRKVGLNDVFVAISGT...,1,1,0,0,1,0,0,1,1,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
36547,MQRDKQIFELILEEQERQIHGLELIASENFVSDEVMEAAGSVLTNK...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
36548,MKKAIFPGSFDPITLGHEDIIKRGIPLFDEIVIAIGVNAEKKYMFP...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
36549,MNQTKYIFVTGGVTSSLGKGIIAASLAKLLQARGYRTTIQKFDPYL...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
36550,MIAHNSKIIGEGLTYDDVLLVPNYSHVLPREVSIKTKFSRNITLNV...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [32]:
mask = (df[target_go_terms] == 0).all(axis=1)

df[mask]

Unnamed: 0,sequence,GO:0005737,GO:0005524,GO:0005829,GO:0046872,GO:0000287,GO:0005886,GO:0008270,GO:0071555,GO:0009252,...,GO:0003977,GO:0003678,GO:0004452,GO:0004731,GO:0008782,GO:0008930,GO:0051745,GO:0036381,GO:0043715,GO:0051213


In [34]:
df["count"] = df[target_go_terms].sum(axis=1)
df

Unnamed: 0,sequence,GO:0005737,GO:0005524,GO:0005829,GO:0046872,GO:0000287,GO:0005886,GO:0008270,GO:0071555,GO:0009252,...,GO:0003678,GO:0004452,GO:0004731,GO:0008782,GO:0008930,GO:0051745,GO:0036381,GO:0043715,GO:0051213,count
0,MAEKSKKPTPNKPKFSAWWIYTAIILVFLALNFFSSGSDFGNPKEI...,0,1,0,0,0,1,1,0,0,...,0,0,0,0,0,0,0,0,0,8
1,MPEIEKTTYQDTLAYAKQADAADPLARFRDMFNIPKDAQGNNLIYL...,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,8
2,MADTKYIFVTGGVSSSLGKGIIAASLAKLLQARGYSVTIQKLDPYI...,0,1,1,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,9
3,MKNTQKDIRALSREELRDFFVSQGEKAFRGNQVYEWLWVKGAHSFD...,1,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,9
4,MILLKDILYKVTLESVTGNPNVPVNAIHFDSRKVGLNDVFVAISGT...,1,1,0,0,1,0,0,1,1,...,0,0,0,0,0,0,0,0,0,8
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
36547,MQRDKQIFELILEEQERQIHGLELIASENFVSDEVMEAAGSVLTNK...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
36548,MKKAIFPGSFDPITLGHEDIIKRGIPLFDEIVIAIGVNAEKKYMFP...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
36549,MNQTKYIFVTGGVTSSLGKGIIAASLAKLLQARGYRTTIQKFDPYL...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
36550,MIAHNSKIIGEGLTYDDVLLVPNYSHVLPREVSIKTKFSRNITLNV...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1


In [35]:
df.sort_values(by="count", ascending=False)

Unnamed: 0,sequence,GO:0005737,GO:0005524,GO:0005829,GO:0046872,GO:0000287,GO:0005886,GO:0008270,GO:0071555,GO:0009252,...,GO:0003678,GO:0004452,GO:0004731,GO:0008782,GO:0008930,GO:0051745,GO:0036381,GO:0043715,GO:0051213,count
26936,MTNELLHAKLKENFGFEKFRPNQETIINTVLSGQDTLAIMPTGGGK...,1,1,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,14
15850,MNSNEIEIHKELKKYFGFSQFKGLQEQVITSILNKKNTFVIMPTGG...,1,1,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,13
4714,MITSQILHTTLKENFGFEKFRSNQEEIINTILQGNDTLAIMPTGGG...,1,1,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,13
8699,MTSEILHAKLKENFGFEKFRPNQETIITTILSGQDTLAIMPTGGGK...,1,1,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,13
8718,MNSNEIEIHKELKKYFGFSQFKGLQEQVITSILEKKNTFVIMPTGG...,1,1,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,13
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
36427,MSTNIKKIGVLTSGGDSPGMNAAIRAVVRACAYHNIECTGIYRGYQ...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
36428,MQYTFYKYQGTGNDFVMIDNRNNLFPKNDTKLVARLCDRKFGIGAD...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
36429,MEYRIEKDTMGEVKVPADKYWGAQTERSRNNFKIGNAASMPKEIVE...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1
36430,MSKDNKPENKFKVSPWLIYAGIFLMLIAINFVGGGSFLNGPEKLSI...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,1


In [36]:
df.drop(columns=["count"], inplace=True)
df

Unnamed: 0,sequence,GO:0005737,GO:0005524,GO:0005829,GO:0046872,GO:0000287,GO:0005886,GO:0008270,GO:0071555,GO:0009252,...,GO:0003977,GO:0003678,GO:0004452,GO:0004731,GO:0008782,GO:0008930,GO:0051745,GO:0036381,GO:0043715,GO:0051213
0,MAEKSKKPTPNKPKFSAWWIYTAIILVFLALNFFSSGSDFGNPKEI...,0,1,0,0,0,1,1,0,0,...,0,0,0,0,0,0,0,0,0,0
1,MPEIEKTTYQDTLAYAKQADAADPLARFRDMFNIPKDAQGNNLIYL...,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,MADTKYIFVTGGVSSSLGKGIIAASLAKLLQARGYSVTIQKLDPYI...,0,1,1,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,MKNTQKDIRALSREELRDFFVSQGEKAFRGNQVYEWLWVKGAHSFD...,1,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,MILLKDILYKVTLESVTGNPNVPVNAIHFDSRKVGLNDVFVAISGT...,1,1,0,0,1,0,0,1,1,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
36547,MQRDKQIFELILEEQERQIHGLELIASENFVSDEVMEAAGSVLTNK...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
36548,MKKAIFPGSFDPITLGHEDIIKRGIPLFDEIVIAIGVNAEKKYMFP...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
36549,MNQTKYIFVTGGVTSSLGKGIIAASLAKLLQARGYRTTIQKFDPYL...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
36550,MIAHNSKIIGEGLTYDDVLLVPNYSHVLPREVSIKTKFSRNITLNV...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [37]:
df.to_csv("./data/seq_go.csv", index=False)

### data - esm and bioPython columns - run once

In [2]:
import pandas as pd
import torch
from transformers import EsmTokenizer, EsmModel
from tqdm import tqdm
import numpy as np

# BioPython imports
# !pip install biopython  # Make sure biopython is installed
from Bio.SeqUtils.ProtParam import ProteinAnalysis


In [3]:

###############################################################################
# 1) LOAD CSV
###############################################################################
# Suppose your DataFrame is stored as "seq_go.csv" with at least these columns:
#   - "sequence"
#   - "GO:0005829"
#   - "GO:0005737"
#   ... etc.
# We read it in as follows:

df = pd.read_csv("./data/seq_go.csv")
print(f"Number of rows: {len(df)}")

import re

# 1. Filter out sequences longer than 1022
df['Length'] = df['sequence'].str.len()
df = df[df['Length'] <= 1022]

# 2. Filter out invalid amino acid characters
def has_invalid_chars(sequence):
    return bool(re.search(r'[^ACDEFGHIKLMNPQRSTVWY]', sequence)) or ' ' in sequence

invalid_sequences_mask = df['sequence'].apply(has_invalid_chars)
df = df[~invalid_sequences_mask].reset_index(drop=True)

print("DataFrame after filtering by length and invalid characters:")
# print(df.head())
print(f"Total valid sequences: {len(df)}")


print("Initial DataFrame:")
df


Number of rows: 36552
DataFrame after filtering by length and invalid characters:
Total valid sequences: 35573
Initial DataFrame:


Unnamed: 0,sequence,GO:0005737,GO:0005524,GO:0005829,GO:0046872,GO:0000287,GO:0005886,GO:0008270,GO:0071555,GO:0009252,...,GO:0003678,GO:0004452,GO:0004731,GO:0008782,GO:0008930,GO:0051745,GO:0036381,GO:0043715,GO:0051213,Length
0,MAEKSKKPTPNKPKFSAWWIYTAIILVFLALNFFSSGSDFGNPKEI...,0,1,0,0,0,1,1,0,0,...,0,0,0,0,0,0,0,0,0,657
1,MPEIEKTTYQDTLAYAKQADAADPLARFRDMFNIPKDAQGNNLIYL...,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,429
2,MADTKYIFVTGGVSSSLGKGIIAASLAKLLQARGYSVTIQKLDPYI...,0,1,1,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,546
3,MKNTQKDIRALSREELRDFFVSQGEKAFRGNQVYEWLWVKGAHSFD...,1,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,347
4,MILLKDILYKVTLESVTGNPNVPVNAIHFDSRKVGLNDVFVAISGT...,1,1,0,0,1,0,0,1,1,...,0,0,0,0,0,0,0,0,0,487
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
35568,MQRDKQIFELILEEQERQIHGLELIASENFVSDEVMEAAGSVLTNK...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,424
35569,MKKAIFPGSFDPITLGHEDIIKRGIPLFDEIVIAIGVNAEKKYMFP...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,151
35570,MNQTKYIFVTGGVTSSLGKGIIAASLAKLLQARGYRTTIQKFDPYL...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,537
35571,MIAHNSKIIGEGLTYDDVLLVPNYSHVLPREVSIKTKFSRNITLNV...,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,490


In [4]:

###############################################################################
# 3) ADD BIOPYTHON FEATURES
###############################################################################
# We'll define a function to calculate sequence-based features using BioPython.

def extract_features(sequence):
    """
    Calculate basic protein features via BioPython’s ProteinAnalysis.
    """
    analysis = ProteinAnalysis(sequence)
    
    return {
        'Amino_Acid_Composition': analysis.count_amino_acids(),
        "Molecular_Weight": analysis.molecular_weight(),
        "Isoelectric_Point": analysis.isoelectric_point(),
        "Aromaticity": analysis.aromaticity(),
        "Instability_Index": analysis.instability_index(),
        "GRAVY": analysis.gravy(),
        "Helix_Fraction": analysis.secondary_structure_fraction()[0],
        "Turn_Fraction": analysis.secondary_structure_fraction()[1],
        "Sheet_Fraction": analysis.secondary_structure_fraction()[2],
    }

# Apply this feature extraction
biopy_features_series = df["sequence"].apply(extract_features)
biopy_features_df = pd.DataFrame(biopy_features_series.tolist())

# Merge BioPython features into original DataFrame
df = pd.concat([df, biopy_features_df], axis=1)
print("\nDataFrame after adding BioPython features:")
df



DataFrame after adding BioPython features:


Unnamed: 0,sequence,GO:0005737,GO:0005524,GO:0005829,GO:0046872,GO:0000287,GO:0005886,GO:0008270,GO:0071555,GO:0009252,...,Length,Amino_Acid_Composition,Molecular_Weight,Isoelectric_Point,Aromaticity,Instability_Index,GRAVY,Helix_Fraction,Turn_Fraction,Sheet_Fraction
0,MAEKSKKPTPNKPKFSAWWIYTAIILVFLALNFFSSGSDFGNPKEI...,0,1,0,0,0,1,1,0,0,...,657,"{'A': 60, 'C': 2, 'D': 41, 'E': 56, 'F': 33, '...",72775.1504,5.549668,0.082192,24.245814,-0.322374,0.371385,0.280061,0.347032
1,MPEIEKTTYQDTLAYAKQADAADPLARFRDMFNIPKDAQGNNLIYL...,1,0,0,0,0,0,0,0,0,...,429,"{'A': 38, 'C': 6, 'D': 24, 'E': 33, 'F': 15, '...",48664.8268,5.540745,0.102564,33.233124,-0.418881,0.337995,0.282051,0.331002
2,MADTKYIFVTGGVSSSLGKGIIAASLAKLLQARGYSVTIQKLDPYI...,0,1,1,1,0,0,0,0,0,...,546,"{'A': 35, 'C': 5, 'D': 31, 'E': 46, 'F': 13, '...",60902.6021,5.591331,0.069597,32.644322,-0.297619,0.331502,0.265568,0.377289
3,MKNTQKDIRALSREELRDFFVSQGEKAFRGNQVYEWLWVKGAHSFD...,1,0,0,1,0,0,0,0,0,...,347,"{'A': 23, 'C': 6, 'D': 28, 'E': 19, 'F': 13, '...",39554.8706,7.593257,0.074928,40.103170,-0.400288,0.305476,0.276657,0.348703
4,MILLKDILYKVTLESVTGNPNVPVNAIHFDSRKVGLNDVFVAISGT...,1,1,0,0,1,0,0,1,1,...,487,"{'A': 42, 'C': 4, 'D': 30, 'E': 35, 'F': 21, '...",53976.7475,5.420587,0.080082,26.392608,-0.207392,0.347023,0.254620,0.392197
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
35568,MQRDKQIFELILEEQERQIHGLELIASENFVSDEVMEAAGSVLTNK...,0,0,0,0,0,0,0,0,0,...,424,"{'A': 40, 'C': 3, 'D': 25, 'E': 33, 'F': 19, '...",46507.5373,5.247455,0.073113,38.398608,-0.140330,0.339623,0.280660,0.341981
35569,MKKAIFPGSFDPITLGHEDIIKRGIPLFDEIVIAIGVNAEKKYMFP...,0,0,0,0,0,0,0,0,0,...,151,"{'A': 10, 'C': 1, 'D': 6, 'E': 14, 'F': 10, 'G...",17129.8181,8.966537,0.092715,36.262252,-0.060265,0.350993,0.238411,0.403974
35570,MNQTKYIFVTGGVTSSLGKGIIAASLAKLLQARGYRTTIQKFDPYL...,0,0,0,0,0,0,0,0,0,...,537,"{'A': 38, 'C': 6, 'D': 25, 'E': 45, 'F': 16, '...",59938.8279,5.900024,0.080074,29.729236,-0.241527,0.335196,0.260708,0.385475
35571,MIAHNSKIIGEGLTYDDVLLVPNYSHVLPREVSIKTKFSRNITLNV...,0,0,0,0,0,0,0,0,0,...,490,"{'A': 47, 'C': 3, 'D': 26, 'E': 30, 'F': 12, '...",52127.4770,7.101941,0.044898,22.595510,0.008776,0.324490,0.285714,0.367347


In [5]:

###############################################################################
# 4) ADD ESM-2 EMBEDDINGS
###############################################################################
# We'll use the "facebook/esm2_t6_8M_UR50D" model as an example.  
# (Feel free to pick a different ESM model if you prefer.)

model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name)
model = EsmModel.from_pretrained(model_name).eval()  # Put model in eval mode

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

def extract_esm_embeddings(sequence):
    """
    Tokenize the sequence and obtain the mean-pooled ESM-2 embedding.
    """
    # Truncate at 1022 if necessary
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=1022)
    inputs = {key: val.to(device) for key, val in inputs.items()}
    
    with torch.no_grad():
        outputs = model(**inputs)
        # Mean across token embeddings to get one embedding vector per sequence
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
    return embedding

# Loop over each sequence, extract embeddings, and store in a list
embeddings = []
for seq in tqdm(df["sequence"], desc="Extracting ESM-2 Embeddings"):
    emb = extract_esm_embeddings(seq)
    embeddings.append(emb)

# Add the embeddings to the DataFrame
df["ESM_Embedding"] = embeddings
print("\nDataFrame after adding ESM embeddings (stored as NumPy arrays):")
df


Loading weights:   0%|          | 0/107 [00:00<?, ?it/s]

[1mEsmModel LOAD REPORT[0m from: facebook/esm2_t6_8M_UR50D
Key                         | Status     | 
----------------------------+------------+-
lm_head.bias                | UNEXPECTED | 
lm_head.dense.bias          | UNEXPECTED | 
lm_head.dense.weight        | UNEXPECTED | 
lm_head.layer_norm.bias     | UNEXPECTED | 
lm_head.layer_norm.weight   | UNEXPECTED | 
esm.embeddings.position_ids | UNEXPECTED | 
pooler.dense.weight         | MISSING    | 
pooler.dense.bias           | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m
Extracting ESM-2 Embeddings: 100%|██████████| 35573/35573 [14:36<00:00, 40.60it/s]


DataFrame after adding ESM embeddings (stored as NumPy arrays):





Unnamed: 0,sequence,GO:0005737,GO:0005524,GO:0005829,GO:0046872,GO:0000287,GO:0005886,GO:0008270,GO:0071555,GO:0009252,...,Amino_Acid_Composition,Molecular_Weight,Isoelectric_Point,Aromaticity,Instability_Index,GRAVY,Helix_Fraction,Turn_Fraction,Sheet_Fraction,ESM_Embedding
0,MAEKSKKPTPNKPKFSAWWIYTAIILVFLALNFFSSGSDFGNPKEI...,0,1,0,0,0,1,1,0,0,...,"{'A': 60, 'C': 2, 'D': 41, 'E': 56, 'F': 33, '...",72775.1504,5.549668,0.082192,24.245814,-0.322374,0.371385,0.280061,0.347032,"[0.035795994, -0.06988472, 0.06342121, 0.25773..."
1,MPEIEKTTYQDTLAYAKQADAADPLARFRDMFNIPKDAQGNNLIYL...,1,0,0,0,0,0,0,0,0,...,"{'A': 38, 'C': 6, 'D': 24, 'E': 33, 'F': 15, '...",48664.8268,5.540745,0.102564,33.233124,-0.418881,0.337995,0.282051,0.331002,"[0.03573818, -0.07426745, 0.10044023, 0.237400..."
2,MADTKYIFVTGGVSSSLGKGIIAASLAKLLQARGYSVTIQKLDPYI...,0,1,1,1,0,0,0,0,0,...,"{'A': 35, 'C': 5, 'D': 31, 'E': 46, 'F': 13, '...",60902.6021,5.591331,0.069597,32.644322,-0.297619,0.331502,0.265568,0.377289,"[-0.07994358, -0.107176796, -0.00058884954, 0...."
3,MKNTQKDIRALSREELRDFFVSQGEKAFRGNQVYEWLWVKGAHSFD...,1,0,0,1,0,0,0,0,0,...,"{'A': 23, 'C': 6, 'D': 28, 'E': 19, 'F': 13, '...",39554.8706,7.593257,0.074928,40.103170,-0.400288,0.305476,0.276657,0.348703,"[-0.16805257, -0.26347956, 0.13902146, 0.21469..."
4,MILLKDILYKVTLESVTGNPNVPVNAIHFDSRKVGLNDVFVAISGT...,1,1,0,0,1,0,0,1,1,...,"{'A': 42, 'C': 4, 'D': 30, 'E': 35, 'F': 21, '...",53976.7475,5.420587,0.080082,26.392608,-0.207392,0.347023,0.254620,0.392197,"[-0.104022294, -0.100982, 0.13634521, 0.245173..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
35568,MQRDKQIFELILEEQERQIHGLELIASENFVSDEVMEAAGSVLTNK...,0,0,0,0,0,0,0,0,0,...,"{'A': 40, 'C': 3, 'D': 25, 'E': 33, 'F': 19, '...",46507.5373,5.247455,0.073113,38.398608,-0.140330,0.339623,0.280660,0.341981,"[0.10813107, -0.032882903, 0.04068126, 0.22721..."
35569,MKKAIFPGSFDPITLGHEDIIKRGIPLFDEIVIAIGVNAEKKYMFP...,0,0,0,0,0,0,0,0,0,...,"{'A': 10, 'C': 1, 'D': 6, 'E': 14, 'F': 10, 'G...",17129.8181,8.966537,0.092715,36.262252,-0.060265,0.350993,0.238411,0.403974,"[-0.20440412, -0.21006401, -0.056837954, 0.310..."
35570,MNQTKYIFVTGGVTSSLGKGIIAASLAKLLQARGYRTTIQKFDPYL...,0,0,0,0,0,0,0,0,0,...,"{'A': 38, 'C': 6, 'D': 25, 'E': 45, 'F': 16, '...",59938.8279,5.900024,0.080074,29.729236,-0.241527,0.335196,0.260708,0.385475,"[-0.070707686, -0.11198315, -0.026034024, 0.24..."
35571,MIAHNSKIIGEGLTYDDVLLVPNYSHVLPREVSIKTKFSRNITLNV...,0,0,0,0,0,0,0,0,0,...,"{'A': 47, 'C': 3, 'D': 26, 'E': 30, 'F': 12, '...",52127.4770,7.101941,0.044898,22.595510,0.008776,0.324490,0.285714,0.367347,"[0.03842143, -0.0030931407, 0.15231434, 0.0721..."


In [6]:
len(df["ESM_Embedding"].iloc[0])

320

In [7]:
# Assuming `df` is your DataFrame and "ESM_Embedding" contains lists (or NumPy arrays)
df2 = df.copy()

# Expand the "ESM_Embedding" column into separate columns
esm_embedding_columns = pd.DataFrame(df2["ESM_Embedding"].tolist(), columns=[f"ESM_Dim_{i}" for i in range(len(df2["ESM_Embedding"].iloc[0]))])

# Expand the "Amino_Acid_Composition" dictionary into separate columns
amino_acid_comp_columns = pd.DataFrame(df2["Amino_Acid_Composition"].apply(pd.Series))

# Concatenate the expanded columns with the original DataFrame (excluding the original columns)
df2 = pd.concat([df2.drop(columns=["ESM_Embedding", "Amino_Acid_Composition"]), esm_embedding_columns, amino_acid_comp_columns], axis=1)

# Save the final DataFrame with expanded features into separate columns
output_file = "./data/final_df_with_features_expanded.csv"
df2.to_csv(output_file, index=False)
print(f"\nFinal DataFrame with BioPython features + expanded ESM embeddings + expanded Amino Acid Composition saved as '{output_file}'.")



Final DataFrame with BioPython features + expanded ESM embeddings + expanded Amino Acid Composition saved as './data/final_df_with_features_expanded.csv'.


In [8]:
df2

Unnamed: 0,sequence,GO:0005737,GO:0005524,GO:0005829,GO:0046872,GO:0000287,GO:0005886,GO:0008270,GO:0071555,GO:0009252,...,M,N,P,Q,R,S,T,V,W,Y
0,MAEKSKKPTPNKPKFSAWWIYTAIILVFLALNFFSSGSDFGNPKEI...,0,1,0,0,0,1,1,0,0,...,15,32,23,19,30,36,33,39,5,16
1,MPEIEKTTYQDTLAYAKQADAADPLARFRDMFNIPKDAQGNNLIYL...,1,0,0,0,0,0,0,0,0,...,15,23,25,17,20,19,15,19,8,21
2,MADTKYIFVTGGVSSSLGKGIIAASLAKLLQARGYSVTIQKLDPYI...,0,1,1,1,0,0,0,0,0,...,10,22,19,21,22,33,33,43,4,21
3,MKNTQKDIRALSREELRDFFVSQGEKAFRGNQVYEWLWVKGAHSFD...,1,0,0,1,0,0,0,0,0,...,11,20,12,17,21,21,16,28,4,9
4,MILLKDILYKVTLESVTGNPNVPVNAIHFDSRKVGLNDVFVAISGT...,1,1,0,0,1,0,0,1,1,...,10,27,12,18,13,25,39,35,1,17
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
35568,MQRDKQIFELILEEQERQIHGLELIASENFVSDEVMEAAGSVLTNK...,0,0,0,0,0,0,0,0,0,...,17,20,16,13,18,20,20,30,1,11
35569,MKKAIFPGSFDPITLGHEDIIKRGIPLFDEIVIAIGVNAEKKYMFP...,0,0,0,0,0,0,0,0,0,...,3,6,7,0,9,8,8,9,0,4
35570,MNQTKYIFVTGGVTSSLGKGIIAASLAKLLQARGYRTTIQKFDPYL...,0,0,0,0,0,0,0,0,0,...,11,28,21,15,21,27,34,41,3,24
35571,MIAHNSKIIGEGLTYDDVLLVPNYSHVLPREVSIKTKFSRNITLNV...,0,0,0,0,0,0,0,0,0,...,13,19,20,13,22,26,32,52,0,10
