<a href="https://colab.research.google.com/github/Sridipta-Roy/Protein-Function-Prediction/blob/main/New_Proteins_Gen.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
from google.colab import drive
import json
import re
import time
import requests
from tqdm.notebook import tqdm
from collections import OrderedDict

In [None]:
# Project root on the mounted Google Drive; every data path below hangs off it.
PROJECT_ROOT = "/content/drive/MyDrive/protein-multimodal"

drive.mount('/content/drive')
os.chdir(PROJECT_ROOT)

# Data layout
DATA_DIR = f"{PROJECT_ROOT}/data"
RAW_DIR = f"{DATA_DIR}/raw"
PROCESSED_DIR = f"{DATA_DIR}/processed"
NEW_DATA_DIR = f"{DATA_DIR}/new_proteins"  # outputs of this notebook

os.makedirs(NEW_DATA_DIR, exist_ok=True)


Mounted at /content/drive


In [None]:
def load_all_existing_accessions(sources=None):
    """Load accessions from ALL existing data sources.

    Each source is a JSON file containing a list of protein records that carry
    an ``"accession"`` key. Loading is best effort:
    - missing files are reported and skipped;
    - files that fail to load or parse are reported and skipped.
    Read the printed log before trusting the exclusion set.

    Args:
        sources: Optional list of ``(filepath, description)`` pairs. Defaults to
            the raw dump plus the train/val/test splits under RAW_DIR and
            PROCESSED_DIR.

    Returns:
        set[str]: Union of every accession found across the sources.
    """
    if sources is None:
        sources = [
            (f"{RAW_DIR}/proteins_raw.json", "Raw data"),
            (f"{PROCESSED_DIR}/train.json", "Train split"),
            (f"{PROCESSED_DIR}/val.json", "Val split"),
            (f"{PROCESSED_DIR}/test.json", "Test split"),
        ]

    existing_accessions = set()

    print("\nüìÇ Loading existing accessions from all sources...")

    for filepath, description in sources:
        if not os.path.exists(filepath):
            # Report rather than silently skip: a missing split means those
            # proteins will NOT be excluded from the new fetch.
            print(f"  - {description}: not found ({filepath})")
            continue

        try:
            with open(filepath, "r") as f:
                data = json.load(f)
            new_accessions = {p["accession"] for p in data}
        except Exception as e:
            print(f"  ‚ö†Ô∏è  {description}: Error loading - {e}")
            continue

        # Count only accessions not already contributed by an earlier source.
        count = len(new_accessions - existing_accessions)
        existing_accessions.update(new_accessions)
        print(f"  ‚úì {description}: +{count} accessions (total: {len(existing_accessions)})")

    if not existing_accessions:
        print("\n‚ö†Ô∏è  No existing data files found!")
    else:
        print(f"\n‚úÖ Total existing accessions to exclude: {len(existing_accessions)}")

    return existing_accessions

# Union of all accessions already present in the raw dump and the train/val/test
# splits; the fetcher below skips every one of them so the new set is disjoint.
existing_accessions = load_all_existing_accessions()


üìÇ Loading existing accessions from all sources...
  ‚úì Raw data: +413 accessions (total: 413)
  ‚úì Train split: +0 accessions (total: 413)
  ‚úì Val split: +0 accessions (total: 413)
  ‚úì Test split: +0 accessions (total: 413)

‚úÖ Total existing accessions to exclude: 413


In [None]:
class UniProtNewProteinFetcher:
    """Fetches NEW proteins not in existing dataset.

    Pages through the UniProt REST search endpoint and keeps entries that pass
    the quality filters in ``_parse_entry``. It skips:
    - accessions listed in ``existing_accessions``;
    - accessions already seen during this run.

    Args:
        max_proteins: Number of new proteins to collect before stopping.
        existing_accessions: Accessions to exclude (e.g. the current dataset).
    """

    BASE_URL = "https://rest.uniprot.org/uniprotkb/search"
    # Give up after this many consecutive failed requests instead of retrying
    # forever (a permanent error such as a malformed query never succeeds).
    MAX_RETRIES = 5
    # The 20 standard amino acids; sequences containing anything else are rejected.
    VALID_AAS = frozenset("ACDEFGHIKLMNPQRSTVWY")

    def __init__(self, max_proteins=100, existing_accessions=None):
        self.max_proteins = max_proteins
        self.existing_accessions = existing_accessions or set()
        self.proteins = []  # accepted, parsed protein records
        self.seen_accessions = set()  # every accession examined so far

    def fetch_new_proteins(self, query="reviewed:true AND annotation_score:5 AND organism_id:9606"):
        """Fetch proteins that are NOT in the existing dataset.

        Pagination follows the cursor URL in the response ``Link`` header
        (``rel="next"``). The search endpoint is cursor-paged and has no
        ``offset`` parameter.

        Args:
            query: UniProt query string.

        Returns:
            list[dict]: ``self.proteins``, the accumulated parsed records.
        """
        print(f"\nüîç Fetching {self.max_proteins} NEW unique proteins...")
        print(f"üìã Query: {query}")
        print(f"üö´ Excluding: {len(self.existing_accessions)} existing accessions\n")

        params = {
            "query": query,
            "format": "json",
            "size": 500,
            "fields": (
                "accession,id,protein_name,cc_function,sequence,length,"
                "organism_name,gene_names,go_id,ec"
            ),
        }

        # First request uses the query params. Later requests use the server's
        # cursor URL, which already encodes the query.
        next_url = self.BASE_URL
        next_params = params
        offset = 0  # number of records received so far
        total_checked = 0
        already_exists = 0
        duplicates = 0
        filtered = 0
        consecutive_failures = 0
        max_offset = 50000  # safety cap on the number of records scanned

        with tqdm(total=self.max_proteins, desc="Fetching new proteins") as pbar:
            while next_url and len(self.proteins) < self.max_proteins and offset < max_offset:
                try:
                    response = requests.get(next_url, params=next_params, timeout=30)
                    response.raise_for_status()
                    data = response.json()
                except (requests.RequestException, ValueError) as e:
                    consecutive_failures += 1
                    print(f"\n‚ùå Error: {e}")
                    if consecutive_failures >= self.MAX_RETRIES:
                        print(f"\nGiving up after {consecutive_failures} consecutive failed requests")
                        break
                    time.sleep(1)
                    continue
                consecutive_failures = 0

                results = data.get("results") or []
                if not results:
                    print(f"\n‚ö†Ô∏è  No more results at offset {offset}")
                    break

                for entry in results:
                    total_checked += 1
                    accession = entry.get("primaryAccession", "")

                    # Skip anything already in the existing dataset
                    if accession in self.existing_accessions:
                        already_exists += 1
                        continue

                    # Skip repeats within this fetch
                    if accession in self.seen_accessions:
                        duplicates += 1
                        continue
                    self.seen_accessions.add(accession)

                    protein_data = self._parse_entry(entry)
                    if protein_data is None:
                        filtered += 1
                        continue

                    self.proteins.append(protein_data)
                    pbar.update(1)
                    if len(self.proteins) >= self.max_proteins:
                        print(f"\n‚úÖ Target reached!")
                        break

                offset += len(results)
                # Cursor URL for the next page; absent on the last page.
                next_url = response.links.get("next", {}).get("url")
                next_params = None

                pbar.set_postfix({
                    'new': len(self.proteins),
                    'exists': already_exists,
                    'filtered': filtered
                })

                time.sleep(0.2)

        print(f"\n‚úÖ Successfully fetched {len(self.proteins)} NEW unique proteins")
        print(f"  üìä Statistics:")
        print(f"     - Total checked: {total_checked}")
        print(f"     - New proteins: {len(self.proteins)}")
        print(f"     - Already existed: {already_exists}")
        print(f"     - Duplicates: {duplicates}")
        print(f"     - Filtered (quality): {filtered}")

        return self.proteins

    def _parse_entry(self, entry):
        """Parse one UniProt JSON entry into a protein record, or return None.

        Returns None when any of the following holds:
        - the sequence is missing;
        - the sequence is shorter than 50 or longer than 1000 residues;
        - the sequence contains non-standard residues;
        - the cleaned FUNCTION text is shorter than 50 characters;
        - the entry has an unexpected structure.
        """
        try:
            accession = entry.get("primaryAccession", "")

            # Quick validation
            sequence_info = entry.get("sequence", {})
            sequence = sequence_info.get("value", "")
            seq_length = sequence_info.get("length", 0)
            if not sequence or seq_length < 50 or seq_length > 1000:
                return None

            # Only the first FUNCTION comment is used; its text fragments are joined.
            function = ""
            for comment in entry.get("comments", []):
                if comment.get("commentType") == "FUNCTION":
                    for text in comment.get("texts", []):
                        function += text.get("value", "") + " "
                    break

            function = function.strip()
            if not function:
                return None

            function = self._clean_text(function)
            if len(function) < 50:
                return None

            if not self.VALID_AAS.issuperset(sequence):
                return None

            protein_name = (
                entry.get("proteinDescription", {})
                .get("recommendedName", {})
                .get("fullName", {})
                .get("value", "")
            )

            organism = entry.get("organism", {}).get("scientificName", "")

            genes = entry.get("genes", [])
            gene_name = genes[0].get("geneName", {}).get("value", "") if genes else ""

            # GO cross-references: id plus the "GoTerm" property when present
            go_terms = []
            for xref in entry.get("uniProtKBCrossReferences", []):
                if xref.get("database") == "GO":
                    go_term = None
                    for prop in xref.get("properties", []):
                        if prop.get("key") == "GoTerm":
                            go_term = prop.get("value")
                            break
                    go_terms.append({"go_id": xref.get("id"), "go_term": go_term})

            # EC numbers from the recommended name only
            ec_list = (
                entry.get("proteinDescription", {})
                .get("recommendedName", {})
                .get("ecNumbers", [])
            )
            ec_numbers = [ec.get("value") for ec in ec_list if ec.get("value")]

            return {
                "accession": accession,
                "protein_name": protein_name,
                "gene_name": gene_name,
                "organism": organism,
                "sequence": sequence,
                "length": seq_length,
                "function": function,
                "function_words": len(function.split()),
                "go_terms": go_terms,
                "ec_numbers": ec_numbers,
            }

        except Exception:
            # Malformed entries are treated as filtered out by the caller.
            return None

    def _clean_text(self, text: str) -> str:
        """Normalize UniProt free text.

        Steps, in order:
        1. Drop ``{...}`` annotations (e.g. evidence tags).
        2. Collapse whitespace.
        3. Keep at most the first three sentences.
        4. Strip characters other than word characters, whitespace and ``.,;:()-/``.
        """
        text = re.sub(r"\{[^}]+\}", " ", text)
        text = re.sub(r"\s+", " ", text).strip()

        sentences = re.split(r"(?<=[.!?])\s+", text)
        if len(sentences) > 3:
            text = " ".join(sentences[:3])

        text = re.sub(r"[^\w\s\.,;:()\-/]", "", text)
        text = re.sub(r"\s+", " ", text).strip()
        return text


In [None]:
# How many previously unseen proteins to collect, and which UniProt entries to draw from.
TARGET_NEW_PROTEINS = 100
FETCH_QUERY = "reviewed:true AND annotation_score:4 AND organism_id:9606"

fetcher = UniProtNewProteinFetcher(
    existing_accessions=existing_accessions,
    max_proteins=TARGET_NEW_PROTEINS,
)
new_proteins = fetcher.fetch_new_proteins(query=FETCH_QUERY)


üîç Fetching 100 NEW unique proteins...
üìã Query: reviewed:true AND annotation_score:4 AND organism_id:9606
üö´ Excluding: 413 existing accessions



Fetching new proteins:   0%|          | 0/100 [00:00<?, ?it/s]


‚úÖ Target reached!

‚úÖ Successfully fetched 100 NEW unique proteins
  üìä Statistics:
     - Total checked: 191
     - New proteins: 100
     - Already existed: 0
     - Duplicates: 0
     - Filtered (quality): 91


In [None]:
# Sanity check: the fetched set must be disjoint from the existing dataset.
new_accessions = {protein["accession"] for protein in new_proteins}
overlaps = new_accessions & existing_accessions

if not overlaps:
    print(f"\n‚úÖ VERIFIED: All {len(new_proteins)} proteins are NEW")
    print(f"   No overlaps with existing dataset")
else:
    print(f"\n‚ùå ERROR: Found {len(overlaps)} overlapping proteins!")
    print(f"   Sample overlaps: {list(overlaps)[:5]}")
    print(f"   Removing overlaps...")

    # Drop every record whose accession is already known
    new_proteins = [protein for protein in new_proteins if protein["accession"] not in overlaps]
    new_accessions = {protein["accession"] for protein in new_proteins}

    print(f"   After removal: {len(new_proteins)} proteins")

# A positive difference means some accession appears more than once.
n_duplicates = len(new_proteins) - len(new_accessions)
if n_duplicates:
    print(f"\n‚ö†Ô∏è  Warning: {n_duplicates} duplicates in new proteins")
else:
    print(f"‚úÖ No duplicates within new proteins")


‚úÖ VERIFIED: All 100 proteins are NEW
   No overlaps with existing dataset
‚úÖ No duplicates within new proteins


In [None]:
if new_proteins:
    # Save raw new proteins
    output_file = f"{NEW_DATA_DIR}/new_proteins_100_raw.json"
    with open(output_file, "w") as f:
        json.dump(new_proteins, f, indent=2)

    print(f"\nüíæ Saved {len(new_proteins)} new proteins to:")
    print(f"   {output_file}")

    # Save summary (new_proteins is non-empty here, so the averages are safe)
    n_new = len(new_proteins)
    summary = {
        "total_new_proteins": n_new,
        "target": TARGET_NEW_PROTEINS,
        "existing_proteins_excluded": len(existing_accessions),
        # Must match the query actually passed to fetch_new_proteins in the fetch cell.
        "query_used": "reviewed:true AND annotation_score:4 AND organism_id:9606",
        "avg_sequence_length": sum(p["length"] for p in new_proteins) / n_new,
        "avg_function_words": sum(p["function_words"] for p in new_proteins) / n_new,
        "min_length": min(p["length"] for p in new_proteins),
        "max_length": max(p["length"] for p in new_proteins),
    }

    summary_file = f"{NEW_DATA_DIR}/new_proteins_100_summary.json"
    with open(summary_file, "w") as f:
        json.dump(summary, f, indent=2)

    print(f"   {summary_file}")

    # Display sample
    print(f"\nüî¨ Sample new protein:")
    sample = new_proteins[0]
    print(f"   Accession: {sample['accession']}")
    print(f"   Name: {sample['protein_name']}")
    print(f"   Length: {sample['length']} aa")
    print(f"   Sequence: {sample['sequence'][:50]}...")
    print(f"   Function: {sample['function'][:100]}...")

    # Display statistics
    print(f"\nüìä Dataset Statistics:")
    print(f"   Total proteins: {len(new_proteins)}")
    print(f"   Avg sequence length: {summary['avg_sequence_length']:.1f} aa")
    print(f"   Avg function words: {summary['avg_function_words']:.1f}")
    print(f"   Length range: {summary['min_length']}-{summary['max_length']} aa")

    print(f"\n‚úÖ SUCCESS! Fetched {len(new_proteins)} new unique proteins")
    print(f"   Ready for embedding generation")
else:
    print(f"\n‚ö†Ô∏è  No new proteins fetched")
    print(f"   Try:")
    print(f"   1. Lowering annotation_score below 4")
    print(f"   2. Including more organisms")
    print(f"   3. Removing organism filter")


üíæ Saved 100 new proteins to:
   /content/drive/MyDrive/protein-multimodal/data/new_proteins_100/new_proteins_100_raw.json
   /content/drive/MyDrive/protein-multimodal/data/new_proteins_100/new_proteins_100_summary.json

üî¨ Sample new protein:
   Accession: A0PJW6
   Name: Transmembrane protein 223
   Length: 202 aa
   Sequence: MAAPWRRWPTGLLAVLRPLLTCRPLQGTTLQRDVLLFEHDRGRFFTILGL...
   Function: Mitochondrial ribosome-associated protein involved in the first steps of cytochrome c oxidase comple...

üìä Dataset Statistics:
   Total proteins: 100
   Avg sequence length: 398.9 aa
   Avg function words: 31.2
   Length range: 56-959 aa

‚úÖ SUCCESS! Fetched 100 new unique proteins
   Ready for embedding generation


### **GENERATE EMBEDDINGS FOR NEW PROTEINS**

In [None]:
if False:
    print("\n‚ö†Ô∏è  No new proteins to generate embeddings for")
else:
    print("\n" + "=" * 80)
    print("GENERATING EMBEDDINGS FOR NEW PROTEINS")
    print("=" * 80)

    # Install ESM if not already installed
    !pip install -q fair-esm umap-learn matplotlib

    import torch
    import esm
    import numpy as np
    import pickle
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA
    import umap

    # Setup directories
    EMBEDDINGS_DIR = f"{NEW_DATA_DIR}/embeddings"
    RESIDUE_DIR = f"{EMBEDDINGS_DIR}/residue_level"

    os.makedirs(EMBEDDINGS_DIR, exist_ok=True)
    os.makedirs(RESIDUE_DIR, exist_ok=True)

    # Check device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nüñ•Ô∏è  Device: {device}")
    if device.type == "cuda":
        print(f"   GPU: {torch.cuda.get_device_name(0)}")
        print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    """### **Load ESM-2 Model**"""

    print("\nüîß Loading ESM-2 model...")

    model_name = "esm2_t33_650M_UR50D"
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    model = model.to(device)
    model.eval()

    batch_converter = alphabet.get_batch_converter()

    print(f"‚úì Model loaded: {model_name}")
    print(f"  Embedding dimension: 1280")
    print(f"  Device: {device}")

    """### **ESM-2 Embedding Generator**"""

    class ESM2EmbeddingGenerator:
        """Generates protein embeddings using ESM-2"""

        def __init__(self, model, alphabet, device="cuda", batch_size=8):
            self.model = model
            self.alphabet = alphabet
            self.batch_converter = alphabet.get_batch_converter()
            self.device = device
            self.batch_size = batch_size

        def generate_embeddings(
            self,
            proteins,
            pool_strategy: str = "mean",
            save_residue: bool = True,
            residue_dir: str = None,
        ):
            """
            Generate embeddings for a list of proteins

            Args:
                proteins: List[dict] with keys: 'accession', 'sequence', 'length', 'function'
                pool_strategy: 'mean', 'cls', or 'both'
                save_residue: if True, saves per-residue embeddings (L, 1280)
                residue_dir: directory to store per-residue embeddings

            Returns:
                embeddings: Tensor of shape [N, D] (D = 1280 or 2560)
                metadata: List[dict] with protein info
            """
            if save_residue and residue_dir is None:
                raise ValueError("residue_dir must be provided when save_residue=True")

            embeddings = []
            metadata = []

            for i in tqdm(range(0, len(proteins), self.batch_size), desc="Generating embeddings"):
                batch_proteins = proteins[i : i + self.batch_size]
                batch_data = [(p["accession"], p["sequence"]) for p in batch_proteins]

                try:
                    batch_embeddings = self._process_batch(
                        batch_data=batch_data,
                        pool_strategy=pool_strategy,
                        save_residue=save_residue,
                        residue_dir=residue_dir,
                    )

                    for j, protein in enumerate(batch_proteins):
                        embeddings.append(batch_embeddings[j].cpu())
                        metadata.append({
                            "accession": protein["accession"],
                            "protein_name": protein.get("protein_name", ""),
                            "length": protein["length"],
                            "sequence": protein["sequence"],
                            "function": protein["function"],
                            "go_terms": protein.get("go_terms", []),
                            "ec_numbers": protein.get("ec_numbers", []),
                        })

                except RuntimeError as e:
                    print(f"\n‚ö†Ô∏è  Skipping batch starting at index {i} due to error: {e}")
                    continue

            # Stack and convert to float16
            embeddings_tensor = torch.stack(embeddings).half()
            return embeddings_tensor, metadata

        def _process_batch(
            self,
            batch_data,
            pool_strategy: str = "mean",
            save_residue: bool = True,
            residue_dir: str = None,
        ):
            """Process a single batch of sequences"""
            batch_labels, batch_strs, batch_tokens = self.batch_converter(batch_data)
            batch_tokens = batch_tokens.to(self.device)

            with torch.no_grad():
                results = self.model(batch_tokens, repr_layers=[33], return_contacts=False)

            token_representations = results["representations"][33]

            batch_embeddings = []
            for i, (label, seq) in enumerate(zip(batch_labels, batch_strs)):
                seq_len = len(seq)

                # Per-residue representation (L, 1280), excluding special tokens
                residue_repr = token_representations[i, 1 : seq_len + 1]

                # Global pooling
                if pool_strategy == "mean":
                    seq_repr = residue_repr.mean(dim=0)
                elif pool_strategy == "cls":
                    seq_repr = token_representations[i, 0]
                elif pool_strategy == "both":
                    mean_repr = residue_repr.mean(dim=0)
                    cls_repr = token_representations[i, 0]
                    seq_repr = torch.cat([mean_repr, cls_repr], dim=0)
                else:
                    raise ValueError(f"Unknown pool_strategy: {pool_strategy}")

                batch_embeddings.append(seq_repr)

                # Save per-residue embeddings (as fp16)
                if save_residue and residue_dir is not None:
                    accession = label
                    safe_accession = accession.replace("/", "_")
                    out_path = os.path.join(residue_dir, f"{safe_accession}.pt")
                    torch.save(residue_repr.half().cpu(), out_path)

            return torch.stack(batch_embeddings)


GENERATING EMBEDDINGS FOR NEW PROTEINS

üñ•Ô∏è  Device: cuda
   GPU: Tesla T4
   Memory: 15.83 GB

üîß Loading ESM-2 model...
‚úì Model loaded: esm2_t33_650M_UR50D
  Embedding dimension: 1280
  Device: cuda


In [None]:
# Embed the new proteins with the same pooling as the original dataset.
POOL_STRATEGY = "mean"  # Same as original

print("\nüîÑ Generating embeddings for new proteins...")

generator = ESM2EmbeddingGenerator(model=model, alphabet=alphabet, device=device, batch_size=8)
new_embeddings, new_metadata = generator.generate_embeddings(
    new_proteins,
    pool_strategy=POOL_STRATEGY,
    save_residue=False,  # per-residue tensors are not written in this run
    residue_dir=RESIDUE_DIR,
)

embedding_shape = new_embeddings.shape
print(f"\n‚úì Generated embeddings shape: {new_embeddings.shape}")
print(f"  Embedding dimension: {embedding_shape[1]}")
print(f"  Number of proteins: {embedding_shape[0]}")
print(f"  Dtype: {new_embeddings.dtype}")

# --- Save embeddings and metadata ---
emb_path = f"{EMBEDDINGS_DIR}/new_proteins_100_embeddings.pt"
meta_path = f"{EMBEDDINGS_DIR}/new_proteins_100_metadata.json"

torch.save(new_embeddings, emb_path)
print(f"\nüíæ Saved embeddings to: {emb_path}")

with open(meta_path, "w") as meta_file:
    json.dump(new_metadata, meta_file, indent=2)
print(f"üíæ Saved metadata to: {meta_path}")

# On-disk sizes in MiB
BYTES_PER_MB = 1024 * 1024
emb_size = os.path.getsize(emb_path) / BYTES_PER_MB
meta_size = os.path.getsize(meta_path) / BYTES_PER_MB
print(f"\nüìä File sizes:")
print(f"  Embeddings: {emb_size:.2f} MB")
print(f"  Metadata: {meta_size:.2f} MB")


üîÑ Generating embeddings for new proteins...


Generating embeddings:   0%|          | 0/13 [00:00<?, ?it/s]


‚úì Generated embeddings shape: torch.Size([100, 1280])
  Embedding dimension: 1280
  Number of proteins: 100
  Dtype: torch.float16

üíæ Saved embeddings to: /content/drive/MyDrive/protein-multimodal/data/new_proteins_100/embeddings/new_proteins_100_embeddings.pt
üíæ Saved metadata to: /content/drive/MyDrive/protein-multimodal/data/new_proteins_100/embeddings/new_proteins_100_metadata.json

üìä File sizes:
  Embeddings: 0.25 MB
  Metadata: 0.14 MB


In [None]:
# Summary statistics are computed on a float32 copy; the stored tensor is fp16.
emb_f32 = new_embeddings.float()

print(f"\nüìà Embedding statistics:")
print(f"  Mean: {emb_f32.mean().item():.4f}")
print(f"  Std:  {emb_f32.std().item():.4f}")
print(f"  Min:  {emb_f32.min().item():.4f}")
print(f"  Max:  {emb_f32.max().item():.4f}")

has_nan = bool(torch.isnan(new_embeddings).any())
has_inf = bool(torch.isinf(new_embeddings).any())

if has_nan:
    print("  ‚ö†Ô∏è  WARNING: NaN values detected in embeddings!")
if has_inf:
    print("  ‚ö†Ô∏è  WARNING: Inf values detected in embeddings!")
if not (has_nan or has_inf):
    print("  ‚úÖ No NaN or Inf values detected")


üìà Embedding statistics:
  Mean: -0.0009
  Std:  0.2076
  Min:  -8.5859
  Max:  1.6904
  ‚úÖ No NaN or Inf values detected


In [None]:
"""### **Verify Embeddings**"""

# Reload the files written above and check that they are consistent.
print("\n" + "=" * 80)
print("VERIFICATION: Loading and checking embeddings")
print("=" * 80)

# Load embeddings
loaded_emb = torch.load(emb_path, map_location="cpu")
print(f"\n‚úì Successfully loaded embeddings")
print(f"  Shape: {loaded_emb.shape}")
print(f"  Dtype: {loaded_emb.dtype}")

# Load metadata
with open(meta_path, "r") as f:
    loaded_meta = json.load(f)
print(f"\n‚úì Successfully loaded metadata")
print(f"  Entries: {len(loaded_meta)}")

# Check alignment (one metadata entry per embedding row)
if loaded_emb.shape[0] == len(loaded_meta):
    print(f"\n‚úÖ Embeddings and metadata are aligned")
else:
    print(f"\n‚ùå ERROR: Mismatch between embeddings ({loaded_emb.shape[0]}) and metadata ({len(loaded_meta)})")

# Residue-level files are only written when embeddings are generated with
# save_residue=True, and the folder may also hold files from earlier runs.
# Only .pt files are counted, sorted so the sample is deterministic.
residue_files = sorted(name for name in os.listdir(RESIDUE_DIR) if name.endswith(".pt"))
print(f"\n‚úì Found {len(residue_files)} residue-level embedding files")
if residue_files:
    sample_res = torch.load(os.path.join(RESIDUE_DIR, residue_files[0]), map_location="cpu")
    print(f"  Sample residue embedding shape: {sample_res.shape}")
    print(f"  Sample residue embedding dtype: {sample_res.dtype}")


VERIFICATION: Loading and checking embeddings

‚úì Successfully loaded embeddings
  Shape: torch.Size([100, 1280])
  Dtype: torch.float16

‚úì Successfully loaded metadata
  Entries: 100

‚úÖ Embeddings and metadata are aligned

‚úì Generated 0 residue-level embedding files
