# DeepBio-Scan: Large-Scale Atlas Seeding (N=100,000)
## Phase 1: Reference Atlas Generation (Optimized)

**Personas Active:**
- `@Embedder-ML` (Model Logic & Inference)
- `@Data-Ops` (Data Pipeline & Parquet Export)

**Hardware Target:** Google Colab T4 GPU (or better)

# DeepBio-Scan: Large-Scale Atlas Seeding (N=100,000)
## Phase 1: Reference Atlas Generation

**Personas Active:**
- `@Embedder-ML` (Model Logic & Inference)
- `@Data-Ops` (Data Pipeline & Parquet Export)

**Hardware Target:** Google Colab T4 GPU (or better)

In [None]:
# @Data-Ops: Dependency Setup
# The lines below use the IPython/Jupyter `!` shell escape; they run in the
# notebook's shell and are not valid in a plain Python script.
# Uninstall torch_xla without an interactive prompt (-y).
# NOTE(review): presumably this avoids a conflict with a preinstalled XLA build
# on Colab — confirm the reason before removing this line.
!pip uninstall -y torch_xla
# Install/upgrade the pipeline dependencies. Only transformers is pinned
# (4.40.2); every other package floats to the latest available version.
!pip install --upgrade transformers==4.40.2 pandas pyarrow duckdb lancedb accelerate biopython tqdm

In [None]:
import os
import time
import torch
import pandas as pd
import numpy as np
from Bio import Entrez, SeqIO
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig
from tqdm.auto import tqdm

# @Embedder-ML: GPU Acceleration Check
# Pick the compute device once: CUDA when torch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Running on: {device.upper()}")
if device != "cuda":
    # Only two values are possible, so "not cuda" means the CPU fallback.
    print("WARNING: GPU not detected. Embedding will be slow.")

In [None]:
# @Data-Ops: Step 1 - High-Speed Batch-fetching logic using Bio.Entrez (FASTA Mode)
# E-mail address configured on Bio.Entrez; the value here is a placeholder.
Entrez.email = "data-ops@deepbio.scan"  # Replace with your email
# NOTE(review): CHECKPOINT_FILE is not referenced anywhere else in the visible
# code (per-batch files are written under "checkpoints/" instead) — confirm
# whether it is still needed.
CHECKPOINT_FILE = "checkpoint_atlas.parquet"

def fetch_taxonomy_metadata(tax_ids):
    """
    @Data-Ops: Taxonomy Metadata Retrieval
    Fetch lineage details for a collection of TaxIDs with a single
    Entrez efetch request against the "taxonomy" database.

    Args:
        tax_ids: Iterable of TaxIDs (strings or ints). None or an empty
            iterable returns {} without contacting Entrez.

    Returns:
        dict mapping each TaxId (as a str) to a dict with the keys
        "Phylum", "Class", "Order", "Family" and "Genus". A rank that does
        not appear in the record's "LineageEx" is reported as "Unknown".

    Errors:
        Best-effort: any exception raised while fetching or parsing is
        printed and the entries collected so far (possibly none) are
        returned, so callers can fall back to "Unknown" values.
    """
    # Normalise to a list of strings so ints and one-shot iterators work too.
    id_list = [str(tid) for tid in tax_ids] if tax_ids is not None else []
    if not id_list:
        return {}

    print(f"Fetching taxonomy for {len(id_list)} unique TaxIDs...")

    # Entrez rank name -> output column name (insertion order = column order).
    rank_to_field = {
        "phylum": "Phylum",
        "class": "Class",
        "order": "Order",
        "family": "Family",
        "genus": "Genus",
    }

    tax_dict = {}
    try:
        # Fetch taxonomy records
        handle = Entrez.efetch(db="taxonomy", id=",".join(id_list), retmode="xml")
        try:
            records = Entrez.read(handle)
        finally:
            # Release the connection even when parsing fails.
            handle.close()

        for record in records:
            tax_id = record.get("TaxId")
            if tax_id is None:
                # Without an ID the entry could never be looked up; skip it.
                continue

            lineage = dict.fromkeys(rank_to_field.values(), "Unknown")
            for taxon in record.get("LineageEx", []):
                field = rank_to_field.get(taxon.get("Rank"))
                if field is not None:
                    lineage[field] = taxon.get("ScientificName")

            tax_dict[str(tax_id)] = lineage
    except Exception as e:
        # Deliberately broad: taxonomy is optional enrichment, not fatal.
        print(f"Error fetching taxonomy: {e}")

    return tax_dict

def _checkpoint_files(checkpoint_dir):
    """Map each batch start index to its 'batch_<start>.parquet' path in checkpoint_dir."""
    prefix, suffix = "batch_", ".parquet"
    found = {}
    for fname in os.listdir(checkpoint_dir):
        if not (fname.startswith(prefix) and fname.endswith(suffix)):
            continue
        stem = fname[len(prefix):-len(suffix)]
        # Ignore stray files (e.g. "batch_old.parquet") instead of crashing on int().
        if stem.isdigit():
            found[int(stem)] = os.path.join(checkpoint_dir, fname)
    return found


def _read_and_close(handle):
    """Parse an Entrez XML handle and close it even if parsing fails."""
    try:
        return Entrez.read(handle)
    finally:
        handle.close()


def _fetch_batch_rows(start, size, webenv, query_key):
    """Download one slice of the Entrez history set and return its filtered row dicts."""
    # @Data-Ops: High-Speed Fetching (FASTA Mode)
    fetch_handle = Entrez.efetch(
        db="nucleotide",
        rettype="fasta",
        retmode="text",
        retstart=start,
        retmax=size,
        webenv=webenv,
        query_key=query_key,
    )
    try:
        batch_records = list(SeqIO.parse(fetch_handle, "fasta"))
    finally:
        fetch_handle.close()

    # FASTA carries no TaxIDs, so fetch the lightweight document summaries
    # for the same slice and build an accession -> TaxID lookup from them.
    docsums = _read_and_close(
        Entrez.esummary(
            db="nucleotide",
            retstart=start,
            retmax=size,
            webenv=webenv,
            query_key=query_key,
        )
    )

    accession_to_taxid = {}
    for docsum in docsums:
        tax_id = str(docsum.get("TaxId", "Unknown"))
        # Register both identifiers, with any ".version" suffix stripped.
        for key in (docsum.get("AccessionVersion", ""), docsum.get("Caption", "")):
            if key:
                accession_to_taxid[key.split('.')[0]] = tax_id

    # Get unique TaxIDs and fetch taxonomy
    unique_taxids = sorted({tid for tid in accession_to_taxid.values() if tid != "Unknown"})
    taxonomy_metadata = fetch_taxonomy_metadata(unique_taxids)

    rows = []
    for seq_record in batch_records:
        accession = seq_record.id.split('.')[0]
        sequence = str(seq_record.seq).upper()

        # @Data-Ops: Biological Filter (200bp - 2000bp to prevent OOM)
        if len(sequence) < 200 or len(sequence) > 2000:
            continue

        # The description is "<id> <free text>"; keep the free-text part as the name.
        desc_parts = seq_record.description.split(' ', 1)
        scientific_name = desc_parts[1] if len(desc_parts) > 1 else "Unknown"

        tax_id = accession_to_taxid.get(accession, "Unknown")
        tax_info = taxonomy_metadata.get(tax_id, {})

        rows.append({
            "AccessionID": accession,
            "ScientificName": scientific_name,
            "TaxID": tax_id,
            "Phylum": tax_info.get("Phylum", "Unknown"),
            "Class": tax_info.get("Class", "Unknown"),
            "Order": tax_info.get("Order", "Unknown"),
            "Family": tax_info.get("Family", "Unknown"),
            "Genus": tax_info.get("Genus", "Unknown"),
            "Sequence": sequence,
            "Quality_Check": True,  # Already filtered for >200bp
        })
    return rows


def fetch_marine_eukaryotes(target_count=100000, batch_size=500, max_retries=3):
    """
    @Data-Ops: Download marine eukaryote sequences into per-batch checkpoints.

    Runs one Entrez esearch (results kept on the history server), then walks
    the result set in slices of `batch_size`. Each slice is fetched as FASTA,
    joined with TaxIDs from esummary and lineage from
    fetch_taxonomy_metadata, filtered to 200-2000 bp, and written to
    "checkpoints/batch_<start>.parquet".

    Resume behaviour: a slice whose checkpoint file already exists is
    skipped; every other slice (never fetched, previously failed, or
    previously empty after filtering) is fetched again. Resuming assumes the
    same `batch_size` as the run that produced the existing files.

    Args:
        target_count: Maximum number of search hits to walk through.
        batch_size: Number of hits per request/checkpoint; must be >= 1.
        max_retries: Attempts per slice before giving up on it for this run
            (values below 1 are treated as 1).

    Returns:
        List of checkpoint file paths, ordered by batch start index. The
        data itself is left on disk to keep memory usage flat.

    Raises:
        ValueError: If `batch_size` is smaller than 1.
        Exception: Errors from the initial esearch propagate; per-slice
            errors are printed, retried and finally skipped.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    print(f"Fetching {target_count} marine eukaryotic sequences...")

    # Search query for marine eukaryotes (e.g., 18S/COI)
    search_query = "eukaryota[Organism] AND (marine[All Fields] OR ocean[All Fields]) AND (18S[All Fields] OR COI[All Fields])"

    # retmax=0: the UID list itself is never read — only Count/WebEnv/QueryKey
    # are needed because the slices are pulled from the history server.
    record = _read_and_close(
        Entrez.esearch(db="nucleotide", term=search_query, retmax=0, usehistory="y")
    )

    webenv = record["WebEnv"]
    query_key = record["QueryKey"]
    total_found = int(record["Count"])
    print(f"Found {total_found} sequences matching query. Fetching up to {target_count}...")

    # @Data-Ops: Robustness - Resume Logic & Memory Optimization
    checkpoint_dir = "checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)

    completed = _checkpoint_files(checkpoint_dir)
    if completed:
        print(f"Found {len(completed)} existing batches. Skipping them and fetching only missing batches...")

    limit = min(target_count, total_found)
    attempts = max(1, max_retries)

    for start in range(0, limit, batch_size):
        if start in completed:
            continue

        # The final slice may be shorter so we never exceed `limit`.
        size = min(batch_size, limit - start)
        print(f"Fetching batch {start} to {start + size}...")

        for attempt in range(1, attempts + 1):
            try:
                batch_data = _fetch_batch_rows(start, size, webenv, query_key)

                # @Data-Ops: OOM Prevention - Save ONLY this batch
                if batch_data:
                    df_batch = pd.DataFrame(batch_data)
                    batch_file = f"checkpoints/batch_{start}.parquet"
                    # Write to a temp name and rename, so an interrupted write
                    # never leaves a truncated file that looks like a finished
                    # checkpoint (the temp name does not end in ".parquet").
                    tmp_file = batch_file + ".tmp"
                    df_batch.to_parquet(tmp_file, engine="pyarrow")
                    os.replace(tmp_file, batch_file)
                    print(f"Batch saved: {batch_file} ({len(df_batch)} records).")
                else:
                    print(f"No valid records in batch {start}.")
                break
            except Exception as e:
                # Best-effort: report, back off a little longer each time, retry.
                print(f"Error fetching batch {start}: {e}")
                time.sleep(5 * attempt)  # Backoff
        else:
            print(f"Giving up on batch {start} after {attempts} attempts; it will be retried on the next run.")

    # Return the list of batch files instead of loading them all into memory
    print("Fetching complete. Batches are saved in the 'checkpoints' directory.")
    all_files = [path for _, path in sorted(_checkpoint_files(checkpoint_dir).items())]
    print(f"Total batches: {len(all_files)}")
    return all_files

# Execute fetching
# Runs the whole Step 1 download as soon as this cell executes. The returned
# list holds the on-disk checkpoint paths, not the sequence data itself.
batch_files = fetch_marine_eukaryotes(target_count=100000, batch_size=500)
print(f"Dataset ready. {len(batch_files)} batches saved to disk.")

In [None]:
# @Embedder-ML: Step 2 - Neural Embedding Pipeline
class LargeScaleEmbedder:
    """
    @Embedder-ML: Loads a Hugging Face masked-LM checkpoint and converts DNA
    sequences into fixed-width, mean-pooled embedding arrays.
    """

    def __init__(self, model_name="InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"):
        """
        Load the config, tokenizer and model for `model_name`, place the model
        on CUDA when available (CPU otherwise) and switch it to eval mode.
        """
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        print(f"Initializing Model: {model_name} on {self.device}...")

        # Load Config and Monkey-Patch
        # @Embedder-ML: Monkey-patch config.intermediate_size = 4096
        # NOTE(review): together with ignore_mismatched_sizes=True below, any
        # weight whose shape no longer matches the patched config would be
        # re-initialised rather than loaded — confirm the checkpoint loads
        # without "newly initialized" warnings.
        patched_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        patched_config.intermediate_size = 4096
        print(f"Monkey-patched intermediate_size to: {patched_config.intermediate_size}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # @Embedder-ML: Reverted to AutoModelForMaskedLM because the custom model architecture requires it
        masked_lm = AutoModelForMaskedLM.from_pretrained(
            model_name,
            config=patched_config,
            trust_remote_code=True,
            ignore_mismatched_sizes=True,
        )
        self.model = masked_lm.to(self.device)

        self.model.eval()
        print("Model successfully loaded.")

    def embedding_generator(self, sequences, batch_size=8):
        """
        Lazily embed `sequences` in mini-batches (default 8, kept small to
        avoid CUDA OOM).

        Yields one float32 NumPy array per mini-batch with one row per
        sequence and exactly 768 columns: the attention-masked mean of the
        last hidden layer, zero-padded (or truncated) to 768.
        """
        target_dim = 768

        for offset in tqdm(range(0, len(sequences), batch_size), desc="Embedding Batches"):
            # Normalise each sequence, then hard-cap it at 1000 characters so
            # memory use stays bounded regardless of input length.
            cleaned = []
            for raw in sequences[offset:offset + batch_size]:
                seq = raw.upper().replace("\n", "").replace("\r", "").replace("N", "A")
                cleaned.append(seq[:1000])

            inputs = self.tokenizer(
                cleaned,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=1000,
            ).to(self.device)

            with torch.no_grad():
                # The MaskedLM head is ignored; hidden states carry the embeddings.
                outputs = self.model(**inputs, output_hidden_states=True)
                last_hidden_state = outputs.hidden_states[-1]

                # Mean pooling over real (non-padding) tokens only.
                mask = inputs["attention_mask"].unsqueeze(-1).expand(last_hidden_state.size()).float()
                token_sum = torch.sum(last_hidden_state * mask, 1)
                token_count = torch.clamp(mask.sum(1), min=1e-9)
                pooled = token_sum / token_count

                # The 50m model has hidden_size=512; widen with zeros (or cut)
                # so every vector has exactly `target_dim` columns.
                width = pooled.shape[1]
                if width < target_dim:
                    zeros = torch.zeros((pooled.shape[0], target_dim - width), device=self.device)
                    pooled = torch.cat([pooled, zeros], dim=1)
                elif width > target_dim:
                    pooled = pooled[:, :target_dim]

                # Extract to CPU immediately
                result = pooled.cpu().numpy().astype(np.float32)

            # @Embedder-ML: Clear CUDA cache BEFORE yielding to prevent fragmentation OOM
            del inputs, outputs, last_hidden_state, mask, token_sum, token_count, pooled
            torch.cuda.empty_cache()

            yield result

# Initialize Embedder
# Created once at cell level with the default model_name; the Step 3 loop
# reuses this single `embedder` instance for every checkpoint file.
embedder = LargeScaleEmbedder()

In [None]:
# @Data-Ops: Step 3 - Process Batches and Export
import pyarrow as pa
import pyarrow.parquet as pq
import gc

print("Starting large-scale embedding generation batch-by-batch...")

# Get all checkpoint files and sort them by their numeric start index.
# The index is parsed from the file name only, so an underscore elsewhere in
# the directory path cannot break the ordering.
batch_files = [os.path.join("checkpoints", f) for f in os.listdir("checkpoints") if f.startswith("batch_") and f.endswith(".parquet")]
batch_files.sort(key=lambda path: int(os.path.basename(path).split('_')[1].split('.')[0]))

output_file = "reference_atlas_100k.parquet"
writer = None
total_records = 0

try:
    for batch_file in tqdm(batch_files, desc="Processing Checkpoints"):
        df_batch = pd.read_parquet(batch_file)

        if df_batch.empty:
            continue

        # Generate embeddings for this checkpoint (one array per mini-batch).
        all_vectors = list(embedder.embedding_generator(df_batch["Sequence"].tolist(), batch_size=8))
        final_vectors = np.concatenate(all_vectors, axis=0)

        # Add vectors to dataframe (one float32 row per record).
        df_batch["Vector"] = list(final_vectors.astype(np.float32))

        # Convert to PyArrow Table
        table = pa.Table.from_pandas(df_batch)

        # Initialize writer with the schema of the first batch
        if writer is None:
            writer = pq.ParquetWriter(output_file, table.schema)

        writer.write_table(table)
        total_records += len(df_batch)

        # Free memory aggressively before loading the next checkpoint.
        del df_batch, all_vectors, final_vectors, table
        gc.collect()
finally:
    # Always close the writer: the Parquet footer is written on close, so
    # skipping it after an error would leave an unreadable output file.
    if writer is not None:
        writer.close()

if writer is None:
    # Nothing was embedded, so no output file exists — do not claim success.
    print(f"WARNING: No records were embedded; {output_file} was not written.")
else:
    print(f"SUCCESS: Atlas saved to {output_file}.")
    print(f"Total records processed and saved: {total_records}")
    print("Ready for LanceDB ingestion.")