# VSim - Data Collection & Model Training
## One-Click Setup: Collect Data + Train Model

This notebook will:
1. Install all dependencies
2. Set up the VSim environment
3. Collect training data from NCBI
4. Train the viability prediction model
5. Save the trained model

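**Run the cells in order.** Step 2 asks you to upload the project ZIP (or clone it from GitHub), and Step 4 holds the settings you should review before the long-running collection and training steps start.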


## Step 1: Install Dependencies


In [None]:
# Install all required packages
!pip install -q biopython numpy pandas scipy scikit-learn torch torchvision torchaudio pyyaml requests aiohttp tqdm backoff matplotlib


## Step 2: Upload Project Files


In [None]:
from google.colab import files
import os
import zipfile
import subprocess
from pathlib import Path

# Obtain the project sources and place them at /content/VSim.
# Order of preference: reuse an existing non-empty directory, clone from
# GitHub (opt-in), otherwise ask for a ZIP upload and extract it.
vsim_dir = Path('/content/VSim')

if vsim_dir.exists() and any(vsim_dir.iterdir()):
    print("✓ VSim directory already exists, skipping upload")
    print("  Location: /content/VSim")
else:
    # Option 1: Clone from GitHub (if you have a repo)
    USE_GITHUB = False  # Set to True and provide your repo URL below
    GITHUB_REPO = "your-username/Project-VSim"  # Replace with your GitHub repo

    if USE_GITHUB and GITHUB_REPO != "your-username/Project-VSim":
        print(f"Cloning from GitHub: {GITHUB_REPO}")
        repo_url = f"https://github.com/{GITHUB_REPO}.git"
        try:
            subprocess.run(['git', 'clone', repo_url, '/content/VSim'], check=True)
            print("✓ Cloned from GitHub")
        except subprocess.CalledProcessError as e:
            print(f"⚠ Failed to clone from GitHub: {e}")
            print("  Please use the ZIP upload method instead")
    else:
        # Option 2: Upload ZIP file
        print("Please upload the Project-VSim folder as a ZIP file:")
        print("1. Zip your Project-VSim folder")
        print("2. Click 'Choose Files' below")
        print("3. Select the ZIP file")
        print()

        uploaded = files.upload()

        # Extract the first uploaded ZIP archive into /content
        extracted = False
        for filename in uploaded.keys():
            if filename.endswith('.zip'):
                print(f"\nExtracting {filename}...")
                with zipfile.ZipFile(filename, 'r') as zip_ref:
                    zip_ref.extractall('/content')
                extracted = True
                break
            else:
                print(f"⚠ {filename} is not a ZIP file. Please upload a ZIP file.")

        # Normalise whatever was extracted to /content/VSim
        if Path('/content/Project-VSim').exists():
            os.rename('/content/Project-VSim', '/content/VSim')
            print("✓ Extraction complete!")
        elif vsim_dir.exists():
            print("✓ VSim directory found!")
        elif extracted:
            # The archive's top-level folder has another name: look for a likely
            # candidate. Sorted so the choice does not depend on listing order.
            dirs = sorted(
                d for d in os.listdir('/content')
                if os.path.isdir(f'/content/{d}') and ('VSim' in d or 'Project' in d)
            )
            if dirs:
                os.rename(f'/content/{dirs[0]}', '/content/VSim')
                print("✓ Extraction complete!")
            else:
                print("⚠ Could not find Project-VSim directory after extraction")
                print("  Please check the ZIP file contents")
        else:
            # Nothing was extracted (upload cancelled or no .zip among the files)
            # and no project directory exists: say so instead of ending silently.
            print("⚠ No ZIP archive was uploaded and /content/VSim does not exist.")
            print("  Re-run this cell and upload the Project-VSim ZIP file.")


## Step 3: Setup Environment


In [None]:
import sys
import os
from pathlib import Path

# Make the project importable: switch into the VSim directory and put it
# first on sys.path. If the directory is missing, list /content to help
# diagnose what Step 2 produced.
project_root = Path('/content/VSim')

if not project_root.exists():
    print("⚠ VSim directory not found. Please upload the Project-VSim ZIP file in Step 2.")
    print("Current directory contents:")
    for entry in os.listdir('/content'):
        print(f"  - {entry}")
else:
    os.chdir(str(project_root))
    sys.path.insert(0, str(project_root))
    print(f"✓ Working directory: {os.getcwd()}")


In [None]:
# Check GPU availability and pick the device used for training.
import os

import torch

# CUDA reads CUDA_VISIBLE_DEVICES when it is first initialised, so the
# variable must be set before the first torch.cuda query below; setting it
# afterwards has no effect on device visibility.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"✓ GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠ GPU not available, using CPU")
    print("⚠ For faster training, enable GPU: Runtime → Change runtime type → GPU")


## Step 4: Configure Data Collection & Training


In [None]:
# ===== MAXIMUM QUALITY CONFIGURATION =====
# Optimized for maximum model quality (training time not a concern)
import getpass
import os


def _read_setting(env_var, prompt, secret=False):
    """Return a setting from the environment, prompting only when it is unset.

    Args:
        env_var: Name of the environment variable checked first.
        prompt: Text shown when the value has to be typed in.
        secret: When True the value is read with getpass instead of input.

    Returns:
        The stripped value, or "" when nothing was entered or no interactive
        input is available.
    """
    value = os.environ.get(env_var, "").strip()
    if value:
        return value
    ask = getpass.getpass if secret else input
    try:
        return ask(prompt).strip()
    except (EOFError, OSError, NotImplementedError):
        # Non-interactive execution: treat the setting as not provided.
        return ""


# Data Collection Settings
# Credentials are not stored in the notebook. They are taken from the
# NCBI_EMAIL / NCBI_API_KEY environment variables, or typed in when this cell
# runs. NOTE(review): the API key that used to be committed here should be
# revoked and replaced.
EMAIL = _read_setting("NCBI_EMAIL", "NCBI contact email (required): ")
# None when no key was given.
# NOTE(review): confirm DataCollector accepts api_key=None.
API_KEY = _read_setting(
    "NCBI_API_KEY",
    "NCBI API key (optional, press Enter to skip): ",
    secret=True,
) or None

if not EMAIL:
    print("⚠ No NCBI email provided. Set NCBI_EMAIL or re-run this cell before Step 5.")

# Training Data Settings - MAXIMUM QUALITY
# Using full 500K dataset for best results
TOTAL_TARGET = 500000  # Full dataset (will take ~24-48 hours for data collection)
# For testing, you can use smaller numbers:
# TOTAL_TARGET = 10000  # Medium dataset (will take ~1-3 hours)
# TOTAL_TARGET = 50000  # Large dataset (will take ~5-10 hours)

# Model Architecture Settings - ENHANCED
INPUT_DIM = 1024  # Increased from 512 (larger feature space)
HIDDEN_DIM = 512  # Increased from 256 (more capacity)
NUM_LAYERS = 8  # Increased from 4 (deeper network)
NUM_HEADS = 16  # Increased from 8 (more attention heads)

# Training Settings - MAXIMUM QUALITY
EPOCHS = 200  # Increased from 20 (more training for better convergence)
BATCH_SIZE = 64  # Increased from 32 (larger batches, adjust based on GPU memory)
LEARNING_RATE = 1e-4  # Learning rate
WEIGHT_DECAY = 1e-5  # Weight decay for regularization
WARMUP_EPOCHS = 10  # Warmup epochs for learning rate scheduling
EARLY_STOPPING_PATIENCE = 20  # Early stopping patience (epochs without improvement)
GRADIENT_CLIP_VAL = 1.0  # Gradient clipping value

# Advanced Features
USE_FOCAL_LOSS = True  # Use focal loss for better handling of class imbalance
USE_MIXED_PRECISION = True  # Use mixed precision training (faster, uses less memory)

print("="*70)
print("MAXIMUM QUALITY CONFIGURATION")
print("="*70)
print(f"Email: {EMAIL}")
print(f"API Key: {'Provided' if API_KEY else 'Not provided'}")
print(f"\nData Collection:")
print(f"  Total Target Genomes: {TOTAL_TARGET:,}")
print(f"\nModel Architecture (Enhanced):")
print(f"  Input Dimension: {INPUT_DIM}")
print(f"  Hidden Dimension: {HIDDEN_DIM}")
print(f"  Number of Layers: {NUM_LAYERS}")
print(f"  Number of Attention Heads: {NUM_HEADS}")
print(f"\nTraining Settings:")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Weight Decay: {WEIGHT_DECAY}")
print(f"  Warmup Epochs: {WARMUP_EPOCHS}")
print(f"  Early Stopping Patience: {EARLY_STOPPING_PATIENCE}")
print(f"  Gradient Clipping: {GRADIENT_CLIP_VAL}")
print(f"\nAdvanced Features:")
print(f"  Focal Loss: {USE_FOCAL_LOSS}")
print(f"  Mixed Precision: {USE_MIXED_PRECISION}")
print("="*70)
print("\n⚠ NOTE: This configuration is optimized for MAXIMUM QUALITY")
print("   Training will take significantly longer but produce the best model")
print("="*70)


## Step 5: Collect Training Data


In [None]:
import asyncio
import logging
import sys
import threading
from pathlib import Path

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Add project root to path
sys.path.insert(0, str(Path.cwd()))

from src.vlab.data.collector import DataCollector


def _run_coroutine(coro):
    """Run ``coro`` to completion and return its result.

    asyncio.run() raises RuntimeError when an event loop is already running in
    the current thread, which is the case inside a notebook kernel. In that
    situation the coroutine is run on a fresh event loop in a worker thread
    while the calling thread waits for it.

    Args:
        coro: The coroutine object to execute.

    Returns:
        Whatever the coroutine returns.

    Raises:
        Any exception raised by the coroutine; KeyboardInterrupt if the wait
        is interrupted (the worker's task is asked to cancel first).
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread: asyncio.run works directly.
        return asyncio.run(coro)

    state = {}

    async def _main():
        # Remember the loop and task so the waiting thread can cancel them.
        state['loop'] = asyncio.get_running_loop()
        state['task'] = asyncio.current_task()
        return await coro

    def _worker():
        try:
            state['value'] = asyncio.run(_main())
        except BaseException as exc:  # re-raised in the calling thread below
            state['error'] = exc

    worker = threading.Thread(target=_worker, name="vsim-data-collection", daemon=True)
    worker.start()
    try:
        worker.join()
    except KeyboardInterrupt:
        # Ask the background task to stop rather than leaving it running.
        loop, task = state.get('loop'), state.get('task')
        if loop is not None and task is not None:
            try:
                loop.call_soon_threadsafe(task.cancel)
            except RuntimeError:
                pass  # loop already closed: nothing left to cancel
        worker.join(timeout=30)
        raise

    if 'error' in state:
        raise state['error']
    return state.get('value')


print("="*70)
print("STARTING DATA COLLECTION")
print("="*70)
print(f"This will collect {TOTAL_TARGET:,} viral genomes from NCBI")
print("This may take a while depending on your target size...")
print("="*70)
print()

# Create collector
collector = DataCollector(email=EMAIL, api_key=API_KEY)

# Collect data
try:
    results = _run_coroutine(collector.collect_all_data(
        max_viable=None,  # Auto-calculate for balanced dataset
        num_synthetic_non_viable=None,  # Auto-calculate
        num_mutated_non_viable=None,  # Auto-calculate
        total_target=TOTAL_TARGET
    ))

    # Get statistics
    stats = collector.get_data_statistics()

    # Print summary
    print("\n" + "="*70)
    print("DATA COLLECTION COMPLETE!")
    print("="*70)
    print(f"Training Data:")
    print(f"  Viable: {stats['train_viable']:,}")
    print(f"  Non-viable: {stats['train_non_viable']:,}")
    print(f"  Total: {stats['total_train']:,}")
    print(f"\nValidation Data:")
    print(f"  Viable: {stats['val_viable']:,}")
    print(f"  Non-viable: {stats['val_non_viable']:,}")
    print(f"  Total: {stats['total_val']:,}")
    print(f"\nGrand Total: {stats['total']:,} genomes")
    print(f"\nData location: data/training/")
    print("="*70)

except KeyboardInterrupt:
    logger.error("\nData collection interrupted by user")
    raise
except Exception as e:
    logger.error(f"Data collection failed: {e}", exc_info=True)
    raise


## Step 6: Train the Model


In [None]:
import torch
import logging
from pathlib import Path
import sys

# Configure logging for this cell and grab a module-level logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Ensure the current working directory (the project root) is importable.
sys.path.insert(0, str(Path.cwd()))

from src.vlab.training.viability_trainer import ViabilityTrainer, collect_training_data
from src.vlab.core.config import VLabConfig

banner = "="*70

print(banner)
print("STARTING MODEL TRAINING")
print(banner)
print(f"Epochs: {EPOCHS}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Learning Rate: {LEARNING_RATE}")
print(f"Device: {device}")
print(banner)
print()

# Build the run configuration: GPU when available, checkpoints under models/.
config = VLabConfig()
config.use_gpu = torch.cuda.is_available()
config.gpu_id = 0
config.models_dir = Path('models')
config.models_dir.mkdir(parents=True, exist_ok=True)

# Location of the data written by the collection step.
data_dir = Path('data/training')
print(f"Loading data from: {data_dir}")

try:
    train_annotations, train_labels, val_annotations, val_labels = collect_training_data(data_dir)

    if not train_annotations:
        raise ValueError("No training data found! Please run data collection first.")

    print(f"\nLoaded {len(train_annotations)} training samples")
    if val_annotations:
        print(f"Loaded {len(val_annotations)} validation samples")

    # Architecture options taken from the configuration cell.
    architecture = {
        'input_dim': INPUT_DIM,
        'hidden_dim': HIDDEN_DIM,
        'num_layers': NUM_LAYERS,
        'num_heads': NUM_HEADS,
        'use_focal_loss': USE_FOCAL_LOSS,
        'use_mixed_precision': USE_MIXED_PRECISION,
    }

    print(f"\nCreating enhanced model with:")
    print(f"  Input dim: {INPUT_DIM}, Hidden dim: {HIDDEN_DIM}")
    print(f"  Layers: {NUM_LAYERS}, Heads: {NUM_HEADS}")
    print(f"  Focal loss: {USE_FOCAL_LOSS}, Mixed precision: {USE_MIXED_PRECISION}")

    trainer = ViabilityTrainer(config, **architecture)

    # Optimisation options taken from the configuration cell.
    schedule = {
        'epochs': EPOCHS,
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'weight_decay': WEIGHT_DECAY,
        'warmup_epochs': WARMUP_EPOCHS,
        'early_stopping_patience': EARLY_STOPPING_PATIENCE,
        'gradient_clip_val': GRADIENT_CLIP_VAL,
    }

    print("\nStarting training with maximum quality configuration...")
    print("This will take a while but will produce the best possible model.")
    print()

    # Empty validation inputs are passed as None.
    trainer.train(
        train_annotations, train_labels,
        val_annotations or None,
        val_labels or None,
        **schedule
    )

    # Persist the weights as they are at the end of training.
    final_model_path = config.models_dir / "viability_model_final.pth"
    trainer.save_model(final_model_path)

    print("\n" + banner)
    print("TRAINING COMPLETE!")
    print(banner)
    print(f"Model saved to: {final_model_path}")

    # Mention the best checkpoint only if such a file is present.
    best_model_path = config.models_dir / "viability_model_best.pth"
    if best_model_path.exists():
        print(f"Best model saved to: {best_model_path}")

    print(banner)

except Exception as e:
    logger.error(f"Training failed: {e}", exc_info=True)
    raise


## Step 7: Verify Model


In [None]:
from pathlib import Path
import torch

# Report which checkpoint files exist under models/ and confirm that the
# final one can be loaded on CPU.
models_dir = Path('models')
final_model = models_dir / "viability_model_final.pth"
best_model = models_dir / "viability_model_best.pth"


def _size_in_mb(path):
    """Return the on-disk size of ``path`` in units of 1024*1024 bytes."""
    return path.stat().st_size / (1024 * 1024)


rule = "="*70
print(rule)
print("MODEL VERIFICATION")
print(rule)

if not final_model.exists():
    print(f"⚠ Final model not found: {final_model}")
else:
    size_mb = _size_in_mb(final_model)
    print(f"✓ Final model: {final_model} ({size_mb:.2f} MB)")

    # Loading is best-effort: a failure is reported, not raised.
    try:
        checkpoint = torch.load(final_model, map_location='cpu')
        print(f"✓ Model loaded successfully")
        model_class = checkpoint.get('model_class', 'Unknown')
        state_dict = checkpoint.get('model_state_dict', {})
        print(f"  Model class: {model_class}")
        print(f"  State dict keys: {len(state_dict)}")
    except Exception as e:
        print(f"⚠ Error loading model: {e}")

if not best_model.exists():
    print(f"ℹ Best model not found (this is okay if training didn't use validation)")
else:
    size_mb = _size_in_mb(best_model)
    print(f"✓ Best model: {best_model} ({size_mb:.2f} MB)")

print(rule)


## Step 8: Download Model


In [None]:
from google.colab import files
from pathlib import Path
import zipfile
import os

# Bundle every trained checkpoint (*.pth) from models/ into one ZIP archive
# and hand it to the browser for download.
models_dir = Path('models')
output_zip = '/tmp/vsim_trained_model.zip'

# Sorted so the archive contents are added and listed in a stable order.
model_files = sorted(models_dir.glob('*.pth')) if models_dir.exists() else []

if not models_dir.exists():
    print("⚠ Models directory not found. Please run training first.")
elif not model_files:
    # Avoid building and downloading an empty archive that reports success.
    print("⚠ No .pth model files found in the models directory. Please run training first.")
else:
    print("Creating model archive...")

    with zipfile.ZipFile(output_zip, 'w', zipfile.ZIP_DEFLATED) as zipf:
        # Store each file under its bare name (no models/ prefix in the archive)
        for model_file in model_files:
            zipf.write(model_file, model_file.name)
            print(f"  Added: {model_file.name}")

    # Download
    print(f"\nDownloading model archive...")
    files.download(output_zip)
    print("\n✓ Model downloaded successfully!")
    print("\nYou can now use this model in your VSim application.")


## Summary

If every cell above finished without errors:

- ✅ **Data Collection**: Complete
- ✅ **Model Training**: Complete
- ✅ **Model Saved**: `models/viability_model_final.pth`

### Next Steps:
1. Download the model using Step 8 above
2. Use the model in your VSim application:
   ```python
   from src.vlab.viability.predictor import ViabilityPredictor
   import torch
   
   model = ViabilityPredictor()
   checkpoint = torch.load('viability_model_final.pth', map_location='cpu')  # also loads on CPU-only machines
   model.load_state_dict(checkpoint['model_state_dict'])
   model.eval()
   ```

### Model Location:
- Final model: `models/viability_model_final.pth`
- Best model: `models/viability_model_best.pth` (if validation was used)
