# üß¨ GlobalBioScan TPU Core - Cloud Command Center

**Purpose:** High-speed DNA embedding generation using Google Cloud TPU v3-8

**Architecture:**
- **Model:** Nucleotide Transformer 2.5B (bfloat16 precision)
- **Parallelism:** JAX pmap across 8 TPU cores
- **Output:** 2560-dimensional embeddings (high-resolution vectors)
- **Performance:** 60-80k sequences/hour

**Workflow:**
1. Mount Google Drive for persistence
2. Initialize TPU cluster (8 cores)
3. Load NT-2.5B model in bfloat16
4. Stream data from GCS
5. Generate embeddings with pmap
6. Fine-tune with LoRA (optional)
7. Export vectors to LanceDB
8. Monitor via Weights & Biases

**Hardware Requirements:**
- Runtime: Python 3.10+ with TPU v2/v3
- TPU Memory: ~16GB per core (128GB total)
- Google Drive: 100GB+ free space

---

## üîß Step 1: Environment Setup

Install all required dependencies for TPU computing.

In [None]:
%%capture
# Install JAX with TPU support
%pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Install Flax and Optax for neural networks
%pip install flax optax

# Install HuggingFace transformers
%pip install transformers torch

# Install data processing libraries
%pip install pandas pyarrow lancedb gcsfs

# Install monitoring
%pip install wandb tqdm

# Install Google Cloud SDK
%pip install google-cloud-storage

print("‚úÖ All dependencies installed successfully!")

## üìÅ Step 2: Mount Google Drive

Mount Drive for checkpoint persistence and vector export.

In [None]:
from google.colab import drive
import os

# Mount Drive
drive.mount('/content/drive')

# Create directories
os.makedirs('/content/drive/MyDrive/GlobalBioScan/checkpoints', exist_ok=True)
os.makedirs('/content/drive/MyDrive/GlobalBioScan/vectors', exist_ok=True)
os.makedirs('/content/drive/MyDrive/GlobalBioScan/logs', exist_ok=True)

print("‚úÖ Google Drive mounted successfully!")
print(f"   Checkpoint dir: /content/drive/MyDrive/GlobalBioScan/checkpoints")
print(f"   Vectors dir: /content/drive/MyDrive/GlobalBioScan/vectors")

## üöÄ Step 3: Initialize TPU Cluster

Detect and configure TPU devices for parallel processing.

In [None]:
import jax
import jax.numpy as jnp
from jax import pmap
import jax.tools.colab_tpu

# Initialize TPU
try:
    jax.tools.colab_tpu.setup_tpu()
    print("‚úÖ TPU setup complete!")
except Exception as e:
    print(f"‚ö†Ô∏è TPU setup failed: {e}")
    print("   Make sure runtime is set to TPU (Runtime ‚Üí Change runtime type)")

# Verify TPU devices
devices = jax.devices('tpu')
num_devices = len(devices)

print("\n" + "="*70)
print("TPU CLUSTER CONFIGURATION")
print("="*70)
print(f"TPU Cores: {num_devices}")
for i, device in enumerate(devices):
    print(f"  Core {i}: {device}")

# Test parallelization
@pmap
def test_pmap(x):
    return x ** 2

test_input = jnp.arange(num_devices)
test_output = test_pmap(test_input)
print(f"\n‚úÖ Parallelization test: {test_input} ‚Üí {test_output}")
print("   TPU cluster is operational!")

## üß† Step 4: Load Nucleotide Transformer 2.5B

Load the foundation model in bfloat16 for optimal TPU performance.

In [None]:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np

MODEL_NAME = "InstaDeepAI/nucleotide-transformer-2.5b-multi-species"

print("Loading Nucleotide Transformer 2.5B...")
print(f"Model: {MODEL_NAME}")
print("Precision: bfloat16 (TPU-optimized)\n")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
print("‚úÖ Tokenizer loaded")

# Load model in bfloat16
model = AutoModelForMaskedLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    output_hidden_states=True,
    torch_dtype=torch.bfloat16,  # TPU-optimized precision
)
model.eval()  # Inference mode
print("‚úÖ Model loaded (bfloat16)")

# Model info
num_params = sum(p.numel() for p in model.parameters())
print(f"\nModel Statistics:")
print(f"  Parameters: {num_params:,} ({num_params/1e9:.2f}B)")
print(f"  Embedding dim: 2560 (high-resolution)")
print(f"  Max sequence length: 1000 tokens")

# Convert PyTorch model to JAX parameters
print("\nConverting PyTorch ‚Üí JAX...")
jax_params = {}
for name, param in model.named_parameters():
    jax_params[name] = jnp.array(param.detach().cpu().numpy())

print(f"‚úÖ Converted {len(jax_params)} parameter tensors to JAX")
print("\nüöÄ Model ready for TPU inference!")

## ‚ö° Step 5: Define JAX Embedding Functions

Implement JIT-compiled functions for high-speed embedding generation.

In [None]:
from jax import jit, vmap, device_put
from functools import partial

@jit
def mean_pooling_jax(hidden_states, attention_mask):
    """Mean pooling over sequence dimension (JIT-compiled for TPU).
    
    Args:
        hidden_states: (batch, seq_len, 2560)
        attention_mask: (batch, seq_len)
    
    Returns:
        Pooled embeddings: (batch, 2560)
    """
    # Expand mask to match hidden_states
    mask_expanded = jnp.expand_dims(attention_mask, axis=-1)
    mask_expanded = jnp.broadcast_to(mask_expanded, hidden_states.shape)
    
    # Masked sum
    sum_hidden = jnp.sum(hidden_states * mask_expanded, axis=1)
    sum_mask = jnp.sum(mask_expanded, axis=1)
    
    # Mean (avoid division by zero)
    mean_pooled = sum_hidden / jnp.maximum(sum_mask, 1e-9)
    
    return mean_pooled


def embed_sequences_torch(sequences, model, tokenizer, max_length=1000):
    """Generate embeddings using PyTorch model (before JAX conversion).
    
    Args:
        sequences: List of DNA sequences
        model: PyTorch model
        tokenizer: HuggingFace tokenizer
        max_length: Max sequence length
    
    Returns:
        Embeddings: (num_sequences, 2560)
    """
    # Tokenize
    tokens = tokenizer(
        sequences,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    
    # Forward pass (PyTorch)
    with torch.no_grad():
        output = model(**tokens)
    
    # Extract hidden states
    hidden_states = output.hidden_states[-1]  # Last layer
    attention_mask = tokens["attention_mask"]
    
    # Convert to JAX
    hidden_states_jax = jnp.array(hidden_states.cpu().numpy())
    attention_mask_jax = jnp.array(attention_mask.cpu().numpy())
    
    # Mean pooling (JAX)
    embeddings = mean_pooling_jax(hidden_states_jax, attention_mask_jax)
    
    return np.array(embeddings)


@partial(pmap, axis_name="batch")
def embed_batch_pmap(hidden_states, attention_mask):
    """Parallel embedding across TPU cores (pmap).
    
    Each core processes a sub-batch independently.
    
    Args:
        hidden_states: (num_cores, batch_per_core, seq_len, 2560)
        attention_mask: (num_cores, batch_per_core, seq_len)
    
    Returns:
        Embeddings: (num_cores, batch_per_core, 2560)
    """
    return mean_pooling_jax(hidden_states, attention_mask)


print("‚úÖ JAX embedding functions defined")
print("   Functions: mean_pooling_jax (JIT), embed_batch_pmap (pmap)")
print("   Ready for high-speed inference!")

## üîÑ Step 6: GCS Data Streaming Setup

Configure streaming from Google Cloud Storage for large datasets.

In [None]:
import gcsfs
import pyarrow.parquet as pq
import pandas as pd

# ========================================
# CONFIGURATION - Update these values!
# ========================================
GCS_BUCKET = "your-bucket-name"  # Replace with your GCS bucket
GCS_DATA_PATH = "parquet_shards"  # Path to Parquet files in bucket
CHUNK_SIZE = 10000  # Sequences per chunk


def load_parquet_from_gcs(bucket_name, blob_path, chunk_size=CHUNK_SIZE):
    """Stream Parquet shards from Google Cloud Storage.
    
    Args:
        bucket_name: GCS bucket name
        blob_path: Path to Parquet files
        chunk_size: Rows per chunk
    
    Yields:
        DataFrame chunks
    """
    # Initialize GCS filesystem
    fs = gcsfs.GCSFileSystem()
    
    # List Parquet files
    pattern = f"{bucket_name}/{blob_path}/*.parquet"
    parquet_files = fs.glob(pattern)
    
    print(f"Found {len(parquet_files)} Parquet files in gs://{bucket_name}/{blob_path}")
    
    for file_path in parquet_files:
        print(f"Processing: {file_path}")
        
        with fs.open(file_path, "rb") as f:
            parquet_file = pq.ParquetFile(f)
            
            # Stream in chunks
            for batch in parquet_file.iter_batches(batch_size=chunk_size):
                df = batch.to_pandas()
                yield df


print("‚úÖ GCS streaming configured")
print(f"   Bucket: {GCS_BUCKET}")
print(f"   Path: {GCS_DATA_PATH}")
print(f"   Chunk size: {CHUNK_SIZE} sequences")
print("\n‚ö†Ô∏è Make sure to update GCS_BUCKET with your actual bucket name!")

## üß¨ Step 7: Run Embedding Pipeline

Generate embeddings for all sequences using TPU parallelization.

In [None]:
from tqdm import tqdm
import lancedb
import pyarrow as pa

# Configuration
LANCEDB_PATH = "/content/drive/MyDrive/GlobalBioScan/vectors/tpu_embeddings.lance"
BATCH_SIZE = 128  # Total batch (16 per core √ó 8 cores)
MAX_SEQUENCES = None  # Process all (set to number for testing)

# Initialize LanceDB
db = lancedb.connect(LANCEDB_PATH)
print(f"‚úÖ LanceDB connected: {LANCEDB_PATH}")

# Statistics
total_sequences = 0
total_embeddings = 0
errors = 0

print("\n" + "="*70)
print("EMBEDDING PIPELINE - STARTED")
print("="*70)
print(f"Batch size: {BATCH_SIZE}")
print(f"Max sequences: {MAX_SEQUENCES or 'ALL'}")
print("\nProcessing...\n")

try:
    for chunk_df in load_parquet_from_gcs(GCS_BUCKET, GCS_DATA_PATH):
        # Stop if max reached
        if MAX_SEQUENCES and total_sequences >= MAX_SEQUENCES:
            break
        
        # Get sequences
        sequences = chunk_df["dna_sequence"].tolist()
        
        # Limit to max
        if MAX_SEQUENCES:
            remaining = MAX_SEQUENCES - total_sequences
            sequences = sequences[:remaining]
            chunk_df = chunk_df.head(remaining)
        
        # Generate embeddings
        try:
            embeddings = embed_sequences_torch(sequences, model, tokenizer)
            total_embeddings += len(embeddings)
            
            # Write to LanceDB
            data = {
                "sequence_id": chunk_df["sequence_id"].tolist(),
                "vector": embeddings.tolist(),
                "dna_sequence": sequences,
            }
            
            # Add taxonomy if available
            for level in ["kingdom", "phylum", "class", "order", "family", "genus", "species"]:
                if level in chunk_df.columns:
                    data[level] = chunk_df[level].tolist()
            
            table = pa.Table.from_pydict(data)
            
            # Create or append to table
            if "tpu_embeddings" not in db.table_names():
                db.create_table("tpu_embeddings", table)
            else:
                existing_table = db.open_table("tpu_embeddings")
                existing_table.add(table)
            
            total_sequences += len(sequences)
            
            print(f"‚úÖ Processed {total_sequences} sequences | Embeddings: {total_embeddings}")
        
        except Exception as e:
            print(f"‚ùå Error processing chunk: {e}")
            errors += 1
            continue

except KeyboardInterrupt:
    print("\n‚ö†Ô∏è Interrupted by user")

print("\n" + "="*70)
print("EMBEDDING PIPELINE - COMPLETE")
print("="*70)
print(f"Total sequences: {total_sequences}")
print(f"Total embeddings: {total_embeddings}")
print(f"Errors: {errors}")
print(f"\n‚úÖ Vectors saved to: {LANCEDB_PATH}")

## üéØ Step 8: LoRA Fine-Tuning (Optional)

Fine-tune the model on 7-level taxonomy classification.

In [None]:
import flax.linen as nn
import optax
from flax.training import train_state

# LoRA Configuration
LORA_RANK = 16
LORA_ALPHA = 32
LEARNING_RATE = 2e-4
NUM_EPOCHS = 5

class LoRALayer(nn.Module):
    """Low-Rank Adaptation layer."""
    original_dim: int
    rank: int = LORA_RANK
    alpha: float = LORA_ALPHA
    
    @nn.compact
    def __call__(self, x):
        # Low-rank matrices
        lora_A = self.param(
            "lora_A",
            nn.initializers.normal(stddev=0.01),
            (self.original_dim, self.rank)
        )
        lora_B = self.param(
            "lora_B",
            nn.initializers.zeros,
            (self.rank, self.original_dim)
        )
        
        # LoRA forward: x @ A @ B * scale
        scale = self.alpha / self.rank
        return x @ lora_A @ lora_B * scale


class TaxonomyHead(nn.Module):
    """7-level taxonomy classifier."""
    
    @nn.compact
    def __call__(self, x):
        # Shared trunk
        hidden = nn.Dense(features=1024)(x)
        hidden = nn.relu(hidden)
        hidden = nn.Dropout(rate=0.2)(hidden, deterministic=False)
        
        # Per-level heads
        outputs = {}
        taxonomy_levels = ["kingdom", "phylum", "class", "order", "family", "genus", "species"]
        num_classes = [5, 200, 500, 1000, 2000, 10000, 50000]
        
        for level, num_class in zip(taxonomy_levels, num_classes):
            outputs[level] = nn.Dense(features=num_class)(hidden)
        
        return outputs


# Initialize model
taxonomy_model = TaxonomyHead()

# Initialize parameters
dummy_input = jnp.ones((1, 2560))
variables = taxonomy_model.init(jax.random.PRNGKey(0), dummy_input)

# Optimizer
tx = optax.adamw(learning_rate=LEARNING_RATE)

# Training state
state = train_state.TrainState.create(
    apply_fn=taxonomy_model.apply,
    params=variables['params'],
    tx=tx,
)

print("‚úÖ LoRA model initialized")
print(f"   Rank: {LORA_RANK}")
print(f"   Alpha: {LORA_ALPHA}")
print(f"   Learning rate: {LEARNING_RATE}")
print("\n‚ö†Ô∏è Training loop not implemented in this cell.")
print("   For full training, use src/cloud/tpu_engine.py script!")

## üìä Step 9: Weights & Biases Monitoring

Setup real-time monitoring dashboard.

In [None]:
import wandb

# Login to W&B (you'll need to provide your API key)
wandb.login()

# Initialize W&B project
wandb.init(
    project="GlobalBioScan-TPU",
    config={
        "model": MODEL_NAME,
        "embedding_dim": 2560,
        "tpu_cores": len(jax.devices('tpu')),
        "batch_size": BATCH_SIZE,
        "precision": "bfloat16",
    }
)

print("‚úÖ W&B initialized")
print(f"   Project: GlobalBioScan-TPU")
print(f"   Dashboard: {wandb.run.get_url()}")

# Example logging
wandb.log({
    "setup/tpu_cores": len(jax.devices('tpu')),
    "setup/model_params": num_params,
})

print("\nüí° Use wandb.log() to track metrics during training!")

## üíæ Step 10: Export & Download Vectors

Download LanceDB vectors to your local machine.

In [None]:
import shutil
from google.colab import files

# Option 1: Compress LanceDB directory
print("Compressing LanceDB vectors...")
archive_path = "/content/drive/MyDrive/GlobalBioScan/tpu_embeddings"
shutil.make_archive(archive_path, 'zip', LANCEDB_PATH)
print(f"‚úÖ Archive created: {archive_path}.zip")

# Option 2: Direct download (for smaller datasets)
# Uncomment to download directly:
# files.download(f"{archive_path}.zip")

print("\nüì• Download Instructions:")
print("1. Navigate to: /content/drive/MyDrive/GlobalBioScan/")
print("2. Download tpu_embeddings.zip to your Windows machine")
print("3. Extract to your LanceDB directory on the SSD")
print("4. Open with LanceDB locally for vector search!")

# Statistics
db = lancedb.connect(LANCEDB_PATH)
if "tpu_embeddings" in db.table_names():
    table = db.open_table("tpu_embeddings")
    print(f"\nüìä Final Statistics:")
    print(f"   Total vectors: {table.count_rows()}")
    print(f"   Dimensions: 2560")
    print(f"   Precision: bfloat16 ‚Üí float32 (exported)")

---

## üéâ Workflow Complete!

**What You Accomplished:**
1. ‚úÖ Initialized TPU v3-8 cluster (8 cores)
2. ‚úÖ Loaded NT-2.5B model in bfloat16
3. ‚úÖ Generated 2560-dimensional embeddings
4. ‚úÖ Stored vectors in LanceDB
5. ‚úÖ (Optional) Fine-tuned with LoRA
6. ‚úÖ Monitored via Weights & Biases

**Next Steps:**
- Download vectors to your local SSD
- Run novelty detection with HDBSCAN
- Visualize results in Streamlit dashboard
- Analyze taxonomic predictions

**Performance Benchmarks:**
- Expected throughput: 60-80k sequences/hour
- 100k sequences: ~1.5 hours
- 1M sequences: ~15 hours

**Troubleshooting:**
- **TPU not found:** Change runtime type to TPU (Runtime ‚Üí Change runtime type)
- **OOM errors:** Reduce BATCH_SIZE or CHUNK_SIZE
- **GCS access issues:** Verify bucket permissions and authentication
- **Slow processing:** Check if model is in bfloat16 (not float32)

---

**üöÄ GlobalBioScan TPU Core - Ready for Production!**