# Real Data Ingestion for Training Pipeline

This notebook ingests real astronomical data to populate the database for training.

## Purpose
- Fetch real astronomical images from SkyView
- Create observation records in database
- Generate mock detections for training
- Enable real data training pipeline testing

## Based on
- ASTR-74: Survey Integration (SkyView/MAST)
- ASTR-73: Observation Models
- data_ingestion_exploration.ipynb examples


In [1]:
# Setup and imports
import sys
import os
import asyncio
import numpy as np
from pathlib import Path
from datetime import datetime, timedelta
from uuid import uuid4

# Add project root to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

print(f"📍 Project root: {project_root}")
print(f"📁 Current working directory: {Path.cwd()}")
print("✅ Path setup complete")


📍 Project root: /home/chris/github/AstrID
📁 Current working directory: /home/chris/github/AstrID/notebooks
✅ Path setup complete


In [5]:
from src.core.constants import TRAINING_PIPELINE_API_KEY

global AUTH_HEADERS
AUTH_HEADERS = {
    "X-API-Key": TRAINING_PIPELINE_API_KEY,
    "Content-Type": "application/json",
}

In [6]:
# Import AstrID components
from src.core.db.session import AsyncSessionLocal
from src.domains.observations.models import Survey, Observation, ObservationStatus
from src.domains.detection.models import Detection, DetectionType, DetectionStatus, Model, ModelRun, ModelType, ModelRunStatus
from src.adapters.external.skyview import SkyViewClient
from src.infrastructure.storage.r2_client import R2StorageClient
from src.adapters.imaging.fits_io import FITSProcessor

print("✅ AstrID components imported successfully")


✅ AstrID components imported successfully


In [7]:
# Create database session
db_session = AsyncSessionLocal()
print("✅ Database session created")

# Initialize clients
skyview_client = SkyViewClient(timeout=60)
r2_client = R2StorageClient()
fits_processor = FITSProcessor()

print("✅ External clients initialized")


✅ Database session created
✅ External clients initialized


In [4]:
# Step 1: Create a survey record
async def create_survey():
    """Create a survey record in the database."""
    
    # Check if survey already exists
    from sqlalchemy import select
    result = await db_session.execute(
        select(Survey).where(Survey.name == "DSS2")
    )
    existing_survey = result.scalar_one_or_none()
    
    if existing_survey:
        print(f"✅ Survey 'DSS2' already exists: {existing_survey.id}")
        return existing_survey
    
    # Create new survey
    survey = Survey(
        name="DSS2",
        description="Digitized Sky Survey 2 - Wide field astronomical images",
        base_url="https://skyview.gsfc.nasa.gov",
        api_endpoint="https://skyview.gsfc.nasa.gov/current/cgi",
        is_active=True
    )
    
    db_session.add(survey)
    await db_session.commit()
    
    print(f"✅ Created survey 'DSS2': {survey.id}")
    return survey

# Create survey
survey = await create_survey()


SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate in certificate chain (_ssl.c:1000)

In [None]:
# Step 2: Ingest real observations from SkyView
async def ingest_real_observations(survey, num_observations=10):
    """Ingest real astronomical observations from SkyView."""
    
    observations = []
    
    # Famous astronomical coordinates for interesting objects
    target_coordinates = [
        (83.633, 22.0145, "M1_Crab_Nebula"),        # Crab Nebula
        (84.053, -5.391, "M42_Orion_Nebula"),      # Orion Nebula  
        (186.265, 12.717, "M104_Sombrero_Galaxy"),  # Sombrero Galaxy
        (194.046, 54.349, "M51_Whirlpool_Galaxy"),  # Whirlpool Galaxy
        (187.706, 12.391, "M87_Giant_Galaxy"),      # M87 Galaxy
        (279.234, 38.784, "Ring_Nebula"),           # Ring Nebula region
        (310.358, 40.257, "Vega_region"),           # Vega region
        (201.365, -11.161, "Spica_region"),         # Spica region
        (68.980, 16.509, "Aldebaran_region"),       # Aldebaran region
        (150.789, 12.176, "Regulus_region"),        # Regulus region
    ]
    
    for i, (ra, dec, name) in enumerate(target_coordinates[:num_observations]):
        try:
            print(f"\\n🔍 Ingesting observation {i+1}/{num_observations}: {name}")
            print(f"   Coordinates: RA={ra:.3f}°, Dec={dec:.3f}°")
            
            # Fetch image from SkyView
            print(f"   📡 Fetching from SkyView...")
            img, info = skyview_client.fetch_reference_image(
                ra_deg=ra,
                dec_deg=dec,
                size_pixels=512,
                fov_deg=0.1,  # 0.1 degree field of view
                survey="DSS2 Red",
            )
            
            if img is None:
                print(f"   ❌ Failed to fetch image: {info.get('error', 'Unknown error')}")
                continue
            
            print(f"   ✅ Image fetched: {img.shape} pixels")
            
            # Save FITS file to temporary location (would be R2 in production)
            fits_filename = f"temp_{name}_{i}.fits"
            fits_path = f"/tmp/{fits_filename}"
            
            # Create observation record
            observation = Observation(
                survey_id=survey.id,
                observation_id=f"skyview_dss2_{i}_{int(ra)}_{int(dec)}",
                ra=ra,
                dec=dec,
                observation_time=datetime.now() - timedelta(days=np.random.randint(1, 365)),
                filter_band="R",  # DSS2 Red
                exposure_time=300.0,  # Typical exposure
                fits_url=f"https://skyview.gsfc.nasa.gov/tempspace/fits/{fits_filename}",
                fits_file_path=fits_path,
                pixel_scale=1.7,  # arcsec/pixel for DSS
                image_width=img.shape[1] if len(img.shape) > 1 else img.shape[0],
                image_height=img.shape[0],
                status=ObservationStatus.PREPROCESSED,  # Mark as ready for detection
            )
            
            db_session.add(observation)
            await db_session.flush()  # Get the ID
            
            observations.append((observation, img))
            print(f"   ✅ Created observation record: {observation.id}")
            
        except Exception as e:
            print(f"   ❌ Error ingesting {name}: {e}")
            continue
    
    await db_session.commit()
    print(f"\\n🎉 Successfully ingested {len(observations)} observations!")
    return observations

# Ingest observations
observations_data = await ingest_real_observations(survey, num_observations=5)


In [None]:
# Step 3: Create a mock ML model record
async def create_mock_model():
    """Create a mock ML model for generating detections."""
    
    from sqlalchemy import select
    result = await db_session.execute(
        select(Model).where(Model.name == "mock_unet_v1")
    )
    existing_model = result.scalar_one_or_none()
    
    if existing_model:
        print(f"✅ Model 'mock_unet_v1' already exists: {existing_model.id}")
        return existing_model
    
    model = Model(
        name="mock_unet_v1",
        version="1.0.0",
        model_type=ModelType.UNET,
        architecture={
            "input_channels": 1,
            "output_channels": 1,
            "depth": 4,
            "initial_filters": 64
        },
        hyperparameters={
            "learning_rate": 0.001,
            "batch_size": 2
        },
        training_dataset="real_skyview_data",
        precision=0.85,
        recall=0.78,
        f1_score=0.81,
        accuracy=0.89,
        is_active=True
    )
    
    db_session.add(model)
    await db_session.commit()
    
    print(f"✅ Created mock model: {model.id}")
    return model

# Create model
model = await create_mock_model()


In [None]:
# Step 4: Generate mock detections with realistic coordinates
async def create_mock_detections(observations_data, model):
    """Create mock detections for the ingested observations."""
    
    detections = []
    
    for observation, img in observations_data:
        # Create a model run record
        model_run = ModelRun(
            model_id=model.id,
            observation_id=observation.id,
            input_image_path=observation.fits_file_path,
            inference_time_ms=np.random.randint(50, 200),
            memory_usage_mb=np.random.randint(1000, 2000),
            total_predictions=np.random.randint(1, 5),
            high_confidence_predictions=np.random.randint(0, 3),
            status=ModelRunStatus.COMPLETED,
            started_at=datetime.now() - timedelta(minutes=np.random.randint(1, 60)),
            completed_at=datetime.now() - timedelta(minutes=np.random.randint(0, 30))
        )
        
        db_session.add(model_run)
        await db_session.flush()  # Get the ID
        
        # Generate 1-3 detections per observation
        num_detections = np.random.randint(1, 4)
        
        for j in range(num_detections):
            # Generate realistic pixel coordinates within image bounds
            img_height, img_width = img.shape[:2] if len(img.shape) > 1 else (img.shape[0], img.shape[0])
            
            pixel_x = np.random.randint(50, img_width - 50)   # Avoid edges
            pixel_y = np.random.randint(50, img_height - 50)  # Avoid edges
            
            # Convert pixel to sky coordinates (simplified)
            pixel_scale = observation.pixel_scale or 1.7  # arcsec/pixel
            ra_offset = (pixel_x - img_width/2) * pixel_scale / 3600.0  # degrees
            dec_offset = (pixel_y - img_height/2) * pixel_scale / 3600.0  # degrees
            
            det_ra = observation.ra + ra_offset
            det_dec = observation.dec + dec_offset
            
            # Create detection with varying confidence and types
            detection_types = [DetectionType.TRANSIENT, DetectionType.VARIABLE, DetectionType.SUPERNOVA, DetectionType.UNKNOWN]
            anomaly_types = ["supernova", "variable_star", "transient", "asteroid", "unknown"]
            
            detection = Detection(
                model_run_id=model_run.id,
                observation_id=observation.id,
                ra=det_ra,
                dec=det_dec,
                pixel_x=pixel_x,
                pixel_y=pixel_y,
                confidence_score=np.random.uniform(0.65, 0.95),  # High confidence
                detection_type=np.random.choice(detection_types),
                model_version="1.0.0",
                inference_time_ms=np.random.randint(10, 50),
                status=DetectionStatus.VALIDATED,  # Mark as validated for training
                is_validated=True,
                validation_confidence=np.random.uniform(0.7, 0.9),
                human_label=np.random.choice(anomaly_types),
                prediction_metadata={
                    "detection_method": "unet_segmentation",
                    "patch_size": [64, 64],
                    "preprocessing_applied": True
                }
            )
            
            db_session.add(detection)
            detections.append(detection)
        
        print(f"   ✅ Created {num_detections} detections for {observation.observation_id}")
    
    await db_session.commit()
    print(f"\\n🎉 Created {len(detections)} total detections!")
    return detections

# Create detections
detections = await create_mock_detections(observations_data, model)


In [None]:
# Step 5: Verify data ingestion
async def verify_ingestion():
    """Verify that real data was ingested successfully."""
    
    from sqlalchemy import select, func
    
    # Count observations
    obs_count = await db_session.execute(select(func.count(Observation.id)))
    obs_total = obs_count.scalar()
    
    # Count detections
    det_count = await db_session.execute(select(func.count(Detection.id)))
    det_total = det_count.scalar()
    
    # Count validated detections
    val_det_count = await db_session.execute(
        select(func.count(Detection.id)).where(Detection.is_validated == True)
    )
    val_det_total = val_det_count.scalar()
    
    # Count surveys
    survey_count = await db_session.execute(select(func.count(Survey.id)))
    survey_total = survey_count.scalar()
    
    print("📊 Database Content After Ingestion:")
    print(f"   🔭 Surveys: {survey_total}")
    print(f"   📊 Observations: {obs_total}")
    print(f"   🎯 Detections: {det_total}")
    print(f"   ✅ Validated detections: {val_det_total}")
    
    # Sample some data
    if val_det_total > 0:
        sample_detections = await db_session.execute(
            select(Detection).where(Detection.is_validated == True).limit(3)
        )
        
        print(f"\\n🔬 Sample Validated Detections:")
        for det in sample_detections.scalars():
            print(f"   • ID: {det.id}")
            print(f"     Type: {det.detection_type.value}")
            print(f"     Confidence: {det.confidence_score:.3f}")
            print(f"     Human Label: {det.human_label}")
            print(f"     Coordinates: ({det.ra:.3f}, {det.dec:.3f})")
    
    return {
        "surveys": survey_total,
        "observations": obs_total,
        "detections": det_total,
        "validated_detections": val_det_total,
    }

# Verify ingestion
verification_results = await verify_ingestion()


In [None]:
# Step 6: Test real data collection for training
async def test_real_data_collection():
    """Test the real data collection for training pipeline."""
    
    print("🧪 Testing Real Data Collection for Training...")
    
    try:
        from src.domains.ml.training_data.services import (
            TrainingDataCollector,
            TrainingDataCollectionParams
        )
        from datetime import datetime, timedelta
        
        # Set up collection parameters
        end_date = datetime.now()
        start_date = end_date - timedelta(days=365)
        
        collection_params = TrainingDataCollectionParams(
            survey_ids=["DSS2"],
            date_range=(start_date, end_date),
            confidence_threshold=0.6,
            max_samples=100,
            validation_status="validated",
        )
        
        # Test data collection
        collector = TrainingDataCollector(db_session, r2_client)
        samples = await collector.collect_training_data(collection_params)
        
        print(f"\\n✅ Real Data Collection Test Results:")
        print(f"   📊 Samples collected: {len(samples)}")
        
        if samples:
            # Show sample details
            for i, sample in enumerate(samples[:3]):
                print(f"   Sample {i+1}:")
                print(f"     Labels: {sample.labels}")
                print(f"     Metadata: {sample.sample_metadata}")
            
            # Test quality validation
            quality_report = collector.validate_data_quality(samples)
            print(f"\\n📈 Quality Report:")
            print(f"   Total samples: {quality_report.total_samples}")
            print(f"   Anomaly ratio: {quality_report.anomaly_ratio:.3f}")
            print(f"   Quality score: {quality_report.quality_score:.3f}")
            print(f"   Issues: {quality_report.issues}")
            
            return True
        else:
            print("   ❌ No samples collected")
            return False
            
    except Exception as e:
        print(f"   ❌ Error testing data collection: {e}")
        import traceback
        traceback.print_exc()
        return False

# Test collection
collection_success = await test_real_data_collection()


In [None]:
# Step 7: Summary and next steps
print("\\n" + "=" * 60)
print("🎉 REAL DATA INGESTION COMPLETE!")
print("=" * 60)

print(f"\\n📊 Ingestion Summary:")
print(f"   Surveys created: {verification_results['surveys']}")
print(f"   Observations ingested: {verification_results['observations']}")
print(f"   Detections created: {verification_results['detections']}")
print(f"   Validated detections: {verification_results['validated_detections']}")

if collection_success:
    print(f"\\n✅ READY FOR REAL DATA TRAINING!")
    print(f"   The training pipeline can now access real astronomical data")
    print(f"   GPU utilization should show meaningful values during training")
    print(f"   Energy tracking will reflect actual compute work")
else:
    print(f"\\n⚠️  Real data collection test failed")
    print(f"   Check the data collection parameters and database connectivity")

print(f"\\n🚀 Next Steps:")
print(f"   1. Run the training notebook again")
print(f"   2. Verify real data loading works")
print(f"   3. Check GPU utilization during training")
print(f"   4. Monitor energy consumption values")

print(f"\\n🔗 Database Status:")
print(f"   Ready for real data training: {'✅' if collection_success else '❌'}")
print(f"   Synthetic fallback available: ✅")

# Clean up session
await db_session.close()
print(f"\\n✅ Database session closed")
