In [None]:
# =========================
# Step 0: Mount Google Drive
# =========================
# The default input, output and progress paths below all live under
# /content/drive/MyDrive, so Drive must be mounted before anything else.
# google.colab is only importable inside a Colab runtime.
from google.colab import drive
drive.mount("/content/drive")

# =========================
# Step 1: Install dependencies
# =========================
# Lines starting with "!" are IPython shell escapes: they only work inside a
# notebook (Colab/Jupyter), not when this file is run with plain `python`.
# prodigal is the gene caller invoked via subprocess in Step 7.
!apt-get install -y prodigal > /dev/null
# fair-esm supplies the `esm` module (ESM2 model); biopython supplies Bio.SeqIO.
!pip install fair-esm biopython --quiet

# =========================
# Step 2: Imports
# =========================
import gc
import json
import os
import re
import subprocess
import tempfile
import traceback

import esm
import numpy as np
import torch
from Bio import SeqIO
from IPython.display import display
from ipywidgets import widgets
from tqdm import tqdm

# =========================
# Step 3: Path inputs
# =========================
# Paths are collected with plain input() prompts (press Enter to keep the
# default shown in brackets). ipywidgets Text boxes are not used here: edits
# made in them reach the kernel as comm messages, which are not processed
# while this cell is still running and blocked in input(), so reading their
# .value afterwards would still return the defaults.
DEFAULT_INPUT_FASTA_DIR = "/content/drive/MyDrive/Average_Structural_ID/ASI_Fasta_Files/ASI_Fasta_Files_Trial"
DEFAULT_ESM_ROOT_DIR = "/content/drive/MyDrive/Average_Structural_ID/ESM2_npz_outputs"
DEFAULT_PROGRESS_JSON = "/content/drive/MyDrive/Average_Structural_ID/esm2_progress.json"


def _prompt_path(label, default):
    """Ask the user for a path; an empty answer keeps *default*.

    Surrounding whitespace (easy to paste by accident) is stripped and a
    leading ``~`` is expanded.
    """
    answer = input(f"{label} [{default}]: ").strip()
    return os.path.expanduser(answer) if answer else default


input_fasta_dir = _prompt_path("Input FASTA Dir", DEFAULT_INPUT_FASTA_DIR)
esm_root_dir = _prompt_path("Output NPZ Dir", DEFAULT_ESM_ROOT_DIR)
progress_json = _prompt_path("Progress JSON", DEFAULT_PROGRESS_JSON)

# Fail fast with a clear message rather than an os.listdir error in Step 7.
if not os.path.isdir(input_fasta_dir):
    raise FileNotFoundError(f"Input FASTA directory not found: {input_fasta_dir}")

os.makedirs(esm_root_dir, exist_ok=True)

# =========================
# Step 4: Load / initialize progress
# =========================
# `progress` tracks FASTA file names that finished ("completed") or raised
# ("failed"); save_progress() writes it back so an interrupted run resumes.
if os.path.exists(progress_json):
    with open(progress_json, encoding="utf-8") as f:
        try:
            progress = json.load(f)
        except json.JSONDecodeError as err:
            raise ValueError(
                f"Progress file {progress_json} is not valid JSON; "
                "repair or delete it to start over"
            ) from err
else:
    progress = {}

# Back-fill missing keys so an older or hand-edited progress file still
# supports the progress["completed"] / progress["failed"] accesses in Step 7.
progress.setdefault("completed", [])
progress.setdefault("failed", [])

# =========================
# Step 5: Load ESM2 model
# =========================
model_name = "esm2_t6_8M_UR50D"  # pretrained checkpoint fetched by fair-esm
MAX_LEN = 2000      # proteins longer than this many residues are skipped in Step 7
BATCH_SIZE = 12     # proteins per forward pass; lower it if the GPU runs out of memory

# Prefer the GPU when one is visible; otherwise fall back to CPU inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Loading ESM2 model...")
model, alphabet = esm.pretrained.load_model_and_alphabet(model_name)
# eval() switches off training-only behavior (e.g. dropout); both eval() and
# to() return the module, so they can be chained.
model = model.eval().to(device)
batch_converter = alphabet.get_batch_converter()

# =========================
# Step 6: Helper functions
# =========================
# Residues kept as-is: the 20 standard amino acids (everything else becomes X).
_STANDARD_AA = frozenset("ACDEFGHIKLMNPQRSTVWY")


def clean_sequence(seq: str) -> str:
    """Normalise a protein sequence before it is handed to the ESM batch converter.

    Residues are upper-cased (so soft-masked lower-case letters are kept rather
    than turned into ``X``), a trailing stop marker ``*`` is removed (Prodigal's
    ``-a`` translations end with one, which previously became a spurious final
    ``X`` token in every protein), and any remaining character outside the 20
    standard amino acids is replaced by ``X``.

    Args:
        seq: Raw protein sequence string.

    Returns:
        The cleaned sequence; may be empty if *seq* was empty or only ``*``.
    """
    seq = seq.upper().rstrip("*")
    return "".join(aa if aa in _STANDARD_AA else "X" for aa in seq)

def read_faa(faa_path):
    """Load a protein FASTA file as a list of ``(record_id, cleaned_sequence)`` tuples.

    Records are returned in file order; each sequence is passed through
    clean_sequence().
    """
    proteins = []
    for rec in SeqIO.parse(faa_path, "fasta"):
        cleaned = clean_sequence(str(rec.seq))
        proteins.append((rec.id, cleaned))
    return proteins

def save_progress():
    """Write the module-level ``progress`` dict to ``progress_json``.

    The JSON is first written to a sibling ``.tmp`` file and then moved over
    the real file with ``os.replace``. On filesystems where rename is atomic,
    an interruption mid-write (e.g. a Colab disconnect) therefore leaves
    either the old or the new file, never a truncated one that would make the
    next run's ``json.load`` fail. The parent directory is created if needed.
    """
    parent = os.path.dirname(progress_json)
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp_path = f"{progress_json}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(progress, f, indent=2)
    os.replace(tmp_path, progress_json)

# =========================
# Step 7: Process genomes
# =========================
FASTA_EXTENSIONS = (".fasta", ".fa", ".fna")


def _run_prodigal(genome_fasta, work_dir, base_name):
    """Run Prodigal on *genome_fasta* and return the path of its protein FASTA.

    All outputs go to *work_dir* (a scratch directory), so nothing is written
    next to the input genomes and an existing ``<name>.faa`` there is never
    overwritten or deleted.

    Raises:
        subprocess.CalledProcessError: Prodigal exited with a non-zero status.
        FileNotFoundError: Prodigal exited cleanly but wrote no protein file.
    """
    faa_file = os.path.join(work_dir, f"{base_name}.faa")
    gff_file = os.path.join(work_dir, f"{base_name}.gff")
    # "-f gff" makes the -o file actually GFF; Prodigal's default -o format is GenBank.
    subprocess.run(
        ["prodigal", "-i", genome_fasta, "-a", faa_file, "-o", gff_file,
         "-f", "gff", "-p", "single"],
        check=True,
    )
    if not os.path.exists(faa_file):
        raise FileNotFoundError(f"Prodigal did not create {faa_file}")
    return faa_file


def _embed_proteins(sequences):
    """Mean-pool final-layer ESM2 token embeddings for each ``(protein_id, sequence)``.

    Sequences that are empty or longer than MAX_LEN are skipped.

    Returns:
        ``(embeddings, n_skipped)``: *embeddings* maps protein id to a 1-D numpy
        array; *n_skipped* is the number of sequences that were not embedded.
    """
    # Ask the model for its last layer instead of hard-coding 6, so switching
    # model_name to a deeper ESM2 checkpoint still yields final-layer output.
    repr_layer = model.num_layers
    # Empty sequences would give a NaN mean over zero tokens.
    kept = [(pid, seq) for pid, seq in sequences if 0 < len(seq) <= MAX_LEN]
    embeddings = {}
    for start in tqdm(range(0, len(kept), BATCH_SIZE), desc="Batches"):
        batch = kept[start:start + BATCH_SIZE]
        _, _, batch_tokens = batch_converter(batch)
        batch_tokens = batch_tokens.to(device)
        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[repr_layer], return_contacts=False)
        token_embeddings = results["representations"][repr_layer]
        for idx, (pid, seq) in enumerate(batch):
            # Token 0 is the prepended start token, so residues sit at 1..len(seq);
            # slicing also excludes any padding added for shorter sequences.
            embeddings[pid] = token_embeddings[idx, 1:len(seq) + 1].mean(0).cpu().numpy()
        del batch_tokens, results, token_embeddings
    return embeddings, len(sequences) - len(kept)


fasta_files = sorted(f for f in os.listdir(input_fasta_dir) if f.lower().endswith(FASTA_EXTENSIONS))
print(f"Found {len(fasta_files)} FASTA files")

for fasta_name in fasta_files:
    if fasta_name in progress["completed"]:
        print(f"⏭️ Skipping completed genome → {fasta_name}")
        continue

    print(f"\n🚀 Processing genome → {fasta_name}")
    genome_fasta = os.path.join(input_fasta_dir, fasta_name)
    base_name = re.sub(r"[^\w.-]", "_", os.path.splitext(fasta_name)[0])
    genome_npz_file = os.path.join(esm_root_dir, f"{base_name}_esm2.npz")

    try:
        # Prodigal scratch files live in a local temp dir that is removed even
        # when Prodigal or parsing fails.
        with tempfile.TemporaryDirectory() as work_dir:
            print("Running Prodigal...")
            faa_file = _run_prodigal(genome_fasta, work_dir, base_name)
            sequences = read_faa(faa_file)

        print(f"Embedding {len(sequences)} proteins...")
        all_embeddings, n_skipped = _embed_proteins(sequences)
        if n_skipped:
            print(f"Skipped {n_skipped} empty or >{MAX_LEN}-residue proteins")

        np.savez_compressed(genome_npz_file, **all_embeddings)
    except Exception:
        print(f"❌ Failed genome → {fasta_name}")
        traceback.print_exc()
        # Failed genomes are retried on every rerun; record each one only once.
        if fasta_name not in progress["failed"]:
            progress["failed"].append(fasta_name)
        save_progress()
    else:
        progress["completed"].append(fasta_name)
        # A genome that failed on an earlier run and has now succeeded is no
        # longer a failure (drop every occurrence left by older runs).
        progress["failed"] = [name for name in progress["failed"] if name != fasta_name]
        save_progress()
        print(f"✅ Finished genome → {fasta_name}")
    finally:
        # Release Python garbage and cached GPU memory once per genome rather
        # than after every batch (empty_cache is a no-op without CUDA).
        gc.collect()
        torch.cuda.empty_cache()

print("\n🎉 All available genomes processed!")
