In [None]:
import os
import sys
import subprocess
import torch
import platform
from pathlib import Path

In [None]:
# Environment report: record interpreter, OS and PyTorch build in the cell output.
print(f"Python: {sys.version}")
print(f"Platform: {platform.platform()}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")

# Per-device details are only queried when a CUDA runtime is usable.
if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"GPU Count: {torch.cuda.device_count()}")
    device_indices = range(torch.cuda.device_count())
    # Pair every device index with its properties object, one device at a time.
    for i, props in zip(device_indices, map(torch.cuda.get_device_properties, device_indices)):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"  Memory: {props.total_memory / 1024**3:.2f} GB")
        print(f"  Compute Capability: {props.major}.{props.minor}")

In [None]:
def _pip_install(*pip_args, required=True):
    """Run ``pip install -q <pip_args>`` with the kernel's own interpreter.

    Args:
        *pip_args: Package names and/or extra pip options, passed through verbatim.
        required: When True a failing install re-raises ``CalledProcessError``.
            When False the failure is reported and the function returns False.

    Returns:
        True if pip exited successfully, False if an optional install failed.
    """
    command = [sys.executable, "-m", "pip", "install", "-q", *pip_args]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as exc:
        if required:
            raise
        print(f"WARNING: optional install failed (exit code {exc.returncode}): {' '.join(pip_args)}")
        return False
    return True


# (pip arguments, required?) in the order they must be installed.
# flash-attn, xformers, bitsandbytes and wandb are marked optional because later
# cells of this notebook already guard their imports with availability checks.
# NOTE: torch is imported before this cell runs, so a freshly installed torch build
# only takes effect after the kernel is restarted.
_INSTALL_PLAN = [
    (("--upgrade", "pip"), True),
    (("torch==2.8.0", "torchvision", "torchaudio", "--index-url", "https://download.pytorch.org/whl/cu126"), True),
    (("torch-geometric",), True),
    (("pyg_lib", "torch_scatter", "torch_sparse", "torch_cluster", "torch_spline_conv", "-f", "https://data.pyg.org/whl/torch-2.8.0+cu126.html"), True),
    (("flash-attn", "--no-build-isolation"), False),
    (("xformers",), False),
    (("triton",), True),
    (("bitsandbytes",), False),
    (("transformers", "accelerate", "datasets"), True),
    (("wandb",), False),
    (("astropy", "astroquery"), True),
    (("biopython",), True),
    (("xarray", "zarr", "netCDF4"), True),
    (("h5py", "tables"), True),
    (("scipy", "scikit-learn", "pandas", "numpy"), True),
    (("matplotlib", "seaborn", "plotly"), True),
    (("tqdm", "rich"), True),
    (("pyyaml", "omegaconf"), True),
    (("einops",), True),
    (("timm",), True),
]

_failed_optional = [
    " ".join(pip_args)
    for pip_args, is_required in _INSTALL_PLAN
    if not _pip_install(*pip_args, required=is_required)
]

if _failed_optional:
    print(f"Required packages installed; optional packages unavailable: {_failed_optional}")
else:
    print("All packages installed successfully")

In [None]:
# Report which optional libraries import cleanly in this kernel.
# Each probe is independent: on ImportError the library is reported as
# "Not available" and the cell continues with the next one.
# NOTE(review): only ImportError is handled; any other exception raised while
# importing one of these packages will stop the cell.
try:
    import flash_attn
    print(f"Flash Attention: {flash_attn.__version__}")
except ImportError:
    print("Flash Attention: Not available")
try:
    import xformers
    print(f"xFormers: {xformers.__version__}")
except ImportError:
    print("xFormers: Not available")
try:
    import bitsandbytes as bnb
    print(f"bitsandbytes: {bnb.__version__}")
except ImportError:
    print("bitsandbytes: Not available")
try:
    import torch_geometric
    print(f"PyTorch Geometric: {torch_geometric.__version__}")
except ImportError:
    print("PyTorch Geometric: Not available")

In [None]:
# Main import cell for the training pipeline.
import warnings
# Silences every warning category for the rest of the session (including
# deprecation warnings emitted by the libraries imported below).
warnings.filterwarnings('ignore')
import logging
import json
import time
import gc
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# Optional dependencies: each import is attempted once and the outcome is recorded
# in a module-level *_AVAILABLE flag that later cells can branch on.
try:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import CPUOffload
    FSDP_AVAILABLE = True
except ImportError:
    FSDP_AVAILABLE = False
try:
    import bitsandbytes as bnb
    BITSANDBYTES_AVAILABLE = True
except ImportError:
    BITSANDBYTES_AVAILABLE = False
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

In [None]:
# logging.FileHandler opens its file immediately and fails if the parent directory
# is missing, so create the directory before configuring logging.
os.makedirs('/workspace', exist_ok=True)

# Send INFO-and-above records both to a persistent log file and to the cell output.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('/workspace/training.log'),
        logging.StreamHandler()
    ]
)
# Module-level logger used by every subsequent cell.
logger = logging.getLogger(__name__)

In [None]:
# Process-wide environment for torch.distributed, NCCL, the CUDA allocator and
# the tokenizers library, applied in a single update.
os.environ.update({
    # torch.distributed rendezvous settings
    'MASTER_ADDR': 'localhost',
    'MASTER_PORT': '12355',
    'WORLD_SIZE': '2',
    'RANK': '0',
    'LOCAL_RANK': '0',
    # GPU visibility and diagnostics
    'CUDA_VISIBLE_DEVICES': '0,1',
    'NCCL_DEBUG': 'INFO',
    'TORCH_DISTRIBUTED_DEBUG': 'DETAIL',
    # allocator / tokenizer behaviour
    'PYTORCH_CUDA_ALLOC_CONF': 'max_split_size_mb:512',
    'TOKENIZERS_PARALLELISM': 'false',
})

In [None]:
# Credentials for external data services.
#
# SECURITY: secrets must never be written into the notebook. They are read from the
# environment of the process that launched the kernel (e.g. pod secrets or an
# untracked .env file). Any key that was previously committed in this cell should be
# treated as leaked and rotated with the issuing service.
#
# Environment variables read here:
#   MAST_API_TOKEN, CDS_API_KEY, NCBI_API_KEY, ESA_GAIA_USERNAME, ESO_USERNAME
NASA_MAST_TOKEN = os.environ.get('MAST_API_TOKEN', '')
CLIMATE_DATA_STORE_KEY = os.environ.get('CDS_API_KEY', '')
NCBI_API_KEY = os.environ.get('NCBI_API_KEY', '')
ESA_GAIA_USERNAME = os.environ.get('ESA_GAIA_USERNAME', '')
ESO_USERNAME = os.environ.get('ESO_USERNAME', '')

# Report (by variable name only, never by value) which credentials are missing so a
# misconfigured environment is visible before any download is attempted.
_missing_credentials = [
    env_name
    for env_name, value in (
        ('MAST_API_TOKEN', NASA_MAST_TOKEN),
        ('CDS_API_KEY', CLIMATE_DATA_STORE_KEY),
        ('NCBI_API_KEY', NCBI_API_KEY),
        ('ESA_GAIA_USERNAME', ESA_GAIA_USERNAME),
        ('ESO_USERNAME', ESO_USERNAME),
    )
    if not value
]
if _missing_credentials:
    print(f"WARNING: credentials not set in environment: {', '.join(_missing_credentials)}")

In [None]:
# Workspace layout: one root with four sub-directories, created on demand.
WORKSPACE_DIR = Path('/workspace')
DATA_DIR, CHECKPOINT_DIR, LOG_DIR, OUTPUT_DIR = (
    WORKSPACE_DIR / subdir_name
    for subdir_name in ('data', 'checkpoints', 'logs', 'outputs')
)
for workspace_subdir in (DATA_DIR, CHECKPOINT_DIR, LOG_DIR, OUTPUT_DIR):
    # parents/exist_ok make this cell safe to re-run.
    workspace_subdir.mkdir(parents=True, exist_ok=True)
logger.info(f"Workspace directories created: {WORKSPACE_DIR}")

In [None]:
# Put the workspace root first on sys.path so the project packages below
# (models, training, data_build) are resolved from WORKSPACE_DIR.
sys.path.insert(0, str(WORKSPACE_DIR))
from models.rebuilt_llm_integration import RebuiltLLMIntegration
from models.rebuilt_graph_vae import RebuiltGraphVAE
from models.rebuilt_datacube_cnn import RebuiltDatacubeCNN
from models.rebuilt_multimodal_integration import RebuiltMultimodalIntegration
from training.unified_multimodal_training import UnifiedMultiModalSystem, MultiModalTrainingConfig, compute_multimodal_loss
from data_build.unified_dataloader_architecture import MultiModalBatch, multimodal_collate_fn, DataLoaderConfig
from data_build.comprehensive_data_annotation_treatment import ComprehensiveDataAnnotationSystem, DataDomain, TreatmentConfig
from data_build.source_domain_mapping import get_source_domain_mapper  # NEW: 1000+ sources support
# Reached only if every project import above succeeded.
logger.info("All model modules imported successfully (with 1000+ sources annotation support)")

In [None]:
@dataclass
class RunPodTrainingConfig:
    """Hyper-parameters and runtime switches for this training notebook.

    All fields have defaults, so ``RunPodTrainingConfig()`` yields the
    configuration used by the cells below.
    """
    # --- Hardware description (only logged / reported in this notebook) ---
    num_gpus: int = 2
    gpu_type: str = "RTX_A5000"
    total_vram_gb: int = 48
    # --- Optimisation schedule ---
    max_epochs: int = 200
    batch_size: int = 1
    gradient_accumulation_steps: int = 32
    # Stored as an independent value; it is not derived from the two fields above.
    effective_batch_size: int = 32
    learning_rate: float = 1e-4
    weight_decay: float = 1e-5
    # NOTE(review): not referenced by the scheduler cell visible in this notebook.
    warmup_steps: int = 1000
    # --- Memory / speed switches ---
    use_mixed_precision: bool = True
    use_gradient_checkpointing: bool = True
    use_8bit_optimizer: bool = True
    use_cpu_offloading: bool = True
    use_flash_attention: bool = True
    # --- Distributed settings ---
    use_distributed: bool = True
    distributed_backend: str = "nccl"
    # --- Checkpointing and logging cadence ---
    save_every_n_epochs: int = 10
    checkpoint_every_hours: int = 2
    log_every_n_steps: int = 50
    # --- Weights & Biases ---
    use_wandb: bool = True
    wandb_project: str = "astrobiology-ai-training"
    wandb_entity: Optional[str] = None
    # --- Stopping criteria and gradient clipping ---
    max_grad_norm: float = 1.0
    early_stopping_patience: int = 20
    target_accuracy: float = 0.96
    max_training_hours: int = 672
# Module-level configuration instance shared by the remaining cells.
config = RunPodTrainingConfig()
logger.info(f"Training configuration initialized: {config}")

In [None]:
def setup_distributed(rank, world_size):
    """Join the torch.distributed process group as ``rank`` of ``world_size``.

    Exports RANK, LOCAL_RANK and WORLD_SIZE, initialises the process group with
    the backend from the module-level ``config`` (unless one is already
    initialised), and selects CUDA device ``rank`` for this process.
    """
    os.environ.update({
        'RANK': str(rank),
        'LOCAL_RANK': str(rank),
        'WORLD_SIZE': str(world_size),
    })

    group_ready = dist.is_initialized()
    if not group_ready:
        # env:// rendezvous reads MASTER_ADDR / MASTER_PORT from the environment.
        dist.init_process_group(
            backend=config.distributed_backend,
            init_method='env://',
            world_size=world_size,
            rank=rank,
            timeout=timedelta(minutes=30)
        )

    torch.cuda.set_device(rank)
    logger.info(f"Distributed training initialized: rank={rank}, world_size={world_size}")
def cleanup_distributed():
    """Destroy the torch.distributed process group if one is initialised."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
    logger.info("Distributed training cleaned up")

In [None]:
class GPUMonitor:
    """Rate-limited reporter of per-GPU memory usage.

    ``log_stats`` writes one log line per visible GPU at most once every
    ``log_interval`` seconds (or immediately when forced) and keeps every
    snapshot it logged in ``stats_history``.
    """

    def __init__(self, log_interval=30):
        self.log_interval = log_interval
        self.last_log_time = time.time()
        self.stats_history = []

    def get_gpu_stats(self):
        """Return ``{'gpu_<i>': {...}}`` with memory figures in GiB, or ``{}`` without CUDA."""
        if not torch.cuda.is_available():
            return {}
        stats = {}
        for i in range(torch.cuda.device_count()):
            total = torch.cuda.get_device_properties(i).total_memory / 1024**3
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            reserved = torch.cuda.memory_reserved(i) / 1024**3
            entry = {
                'allocated_gb': allocated,
                'reserved_gb': reserved,
                'total_gb': total,
                # Share of total device memory currently allocated by tensors.
                'utilization_pct': (allocated / total) * 100
            }
            stats[f'gpu_{i}'] = entry
        return stats

    def log_stats(self, force=False):
        """Log a snapshot if the interval elapsed (or ``force``); return it, else ``None``."""
        current_time = time.time()
        interval_elapsed = (current_time - self.last_log_time) >= self.log_interval
        if not (force or interval_elapsed):
            return None

        stats = self.get_gpu_stats()
        self.stats_history.append({'timestamp': current_time, 'stats': stats})
        for gpu_id in stats:
            gpu_stats = stats[gpu_id]
            logger.info(f"{gpu_id}: {gpu_stats['allocated_gb']:.2f}GB / {gpu_stats['total_gb']:.2f}GB ({gpu_stats['utilization_pct']:.1f}%)")
        self.last_log_time = current_time
        return stats
# Shared monitor instance; the training and reporting cells call it by this name.
gpu_monitor = GPUMonitor(log_interval=30)
logger.info("GPU monitor initialized")

In [None]:
class MockMultiModalDataset(Dataset):
    """Synthetic multi-modal dataset producing random samples for pipeline testing.

    Every sample is generated on the fly from a ``torch.Generator`` seeded with the
    sample index, so ``dataset[i]`` returns the same tensors on every access and in
    every DataLoader worker. (The previous implementation seeded NumPy's global RNG
    while drawing all values from torch's RNG, so samples were not reproducible.)

    Args:
        num_samples: Number of samples reported by ``len()``.
        split: Split label stored in each sample's ``metadata``.
    """

    def __init__(self, num_samples=10000, split='train'):
        self.num_samples = num_samples
        self.split = split

    def __len__(self):
        return self.num_samples

    def _make_generator(self, idx):
        """Return a fresh CPU generator seeded with ``idx`` (private to this sample)."""
        generator = torch.Generator()
        generator.manual_seed(int(idx))
        return generator

    def __getitem__(self, idx):
        """Build the sample dict for index ``idx``.

        Returns a dict with keys ``run_id``, ``planet_params``, ``climate_cube``,
        ``bio_graph`` (a torch_geometric ``Data``), ``spectrum``,
        ``text_description``, ``habitability_label`` (int in {0, 1}) and ``metadata``.
        """
        # Imported lazily so the class can be defined without torch_geometric.
        from torch_geometric.data import Data as PyGData

        rng = self._make_generator(idx)
        climate_cube = torch.randn(2, 12, 32, 64, 10, generator=rng)
        num_nodes = 50
        num_edges = 100
        edge_index = torch.randint(0, num_nodes, (2, num_edges), generator=rng)
        node_features = torch.randn(num_nodes, 64, generator=rng)
        bio_graph = PyGData(x=node_features, edge_index=edge_index, num_nodes=num_nodes)
        spectroscopy = torch.randn(1000, 3, generator=rng)
        text_description = f"Exoplanet {idx}: habitability assessment"
        habitability_label = torch.randint(0, 2, (1,), generator=rng).item()
        planet_params = torch.randn(10, generator=rng)
        return {
            'run_id': idx,
            'planet_params': planet_params,
            'climate_cube': climate_cube,
            'bio_graph': bio_graph,
            'spectrum': spectroscopy,
            'text_description': text_description,
            'habitability_label': habitability_label,
            'metadata': {'split': self.split}
        }
# Progress marker: written once the dataset class definition has executed.
logger.info("Mock dataset class defined")

In [None]:
def custom_collate_fn(batch):
    """Merge a list of sample dicts into one batch dict.

    Tensor fields are stacked along a new leading dimension, graphs are merged
    with torch_geometric's ``Batch.from_data_list``, and text / metadata stay as
    Python lists. Note that several output keys differ from the per-sample keys
    (``run_ids``, ``climate_datacube``, ``metabolic_graph``, ``spectroscopy``).
    """
    from torch_geometric.data import Batch as PyGBatch

    def column(key):
        # Values of one field across all samples, in batch order.
        return [sample[key] for sample in batch]

    return {
        'run_ids': torch.tensor(column('run_id'), dtype=torch.long),
        'planet_params': torch.stack(column('planet_params')),
        'climate_datacube': torch.stack(column('climate_cube')),
        'metabolic_graph': PyGBatch.from_data_list(column('bio_graph')),
        'spectroscopy': torch.stack(column('spectrum')),
        'text_description': column('text_description'),
        'habitability_label': torch.tensor(column('habitability_label'), dtype=torch.long),
        'metadata': column('metadata')
    }
logger.info("Custom collate function defined")

In [None]:
def _build_loader(dataset, *, shuffle, num_workers):
    """Wrap ``dataset`` in a DataLoader with the settings common to all three splits."""
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=custom_collate_fn,
        persistent_workers=True
    )


# Synthetic splits: 8000 / 1000 / 1000 samples.
train_dataset = MockMultiModalDataset(num_samples=8000, split='train')
val_dataset = MockMultiModalDataset(num_samples=1000, split='val')
test_dataset = MockMultiModalDataset(num_samples=1000, split='test')

# Only the training loader shuffles; it also gets more worker processes.
train_loader = _build_loader(train_dataset, shuffle=True, num_workers=4)
val_loader = _build_loader(val_dataset, shuffle=False, num_workers=2)
test_loader = _build_loader(test_dataset, shuffle=False, num_workers=2)
logger.info(f"Data loaders created: train={len(train_loader)}, val={len(val_loader)}, test={len(test_loader)}")

In [None]:
# Annotation system and treatment configuration, both built with their defaults.
# The log messages previously contained mojibake (UTF-8 emoji decoded as cp1252);
# the intended characters are restored below.
annotation_system = ComprehensiveDataAnnotationSystem()
treatment_config = TreatmentConfig()
logger.info("✅ Comprehensive data annotation system initialized (supports 14 domains)")
annotation_stats = annotation_system.get_statistics()
logger.info(f"📊 Annotation statistics: {annotation_stats}")

# Initialize source mapper for 1000+ sources
source_mapper = get_source_domain_mapper()
# The statistics dict is read with keys 'total_sources', 'by_domain' and
# 'avg_quality_score' in the log lines below.
mapper_stats = source_mapper.get_statistics()
logger.info(f"✅ Source mapper initialized: {mapper_stats['total_sources']} sources across {len(mapper_stats['by_domain'])} domains")
logger.info(f"📊 Sources by domain: {mapper_stats['by_domain']}")
logger.info(f"📊 Average quality score: {mapper_stats['avg_quality_score']:.3f}")

In [None]:
def _build_multimodal_config():
    """Assemble the MultiModalTrainingConfig from per-component settings.

    Batch size, gradient accumulation and the memory-saving switches are taken
    from the module-level ``config``; everything else is fixed here.
    """
    llm_settings = {
        'hidden_size': 4352,
        'num_attention_heads': 64,
        'use_4bit_quantization': False,
        'use_lora': False,
        'use_scientific_reasoning': True,
        'use_rope': True,
        'use_gqa': True,
        'use_rms_norm': True,
        'use_swiglu': True
    }
    graph_settings = {
        'node_features': 64,
        'hidden_dim': 512,
        'latent_dim': 256,
        'max_nodes': 50,
        'num_layers': 6,
        'heads': 16,
        'use_biochemical_constraints': True
    }
    cnn_settings = {
        'input_variables': 2,
        'output_variables': 2,
        'base_channels': 128,
        'depth': 6,
        'use_attention': True,
        'use_physics_constraints': True,
        'embed_dim': 256,
        'num_heads': 8,
        'num_transformer_layers': 6,
        'use_vit_features': True,
        'use_gradient_checkpointing': True
    }
    fusion_settings = {
        'fusion_dim': 1024,
        'num_attention_heads': 8,
        'num_fusion_layers': 3,
        'use_adaptive_weighting': True
    }
    return MultiModalTrainingConfig(
        llm_config=llm_settings,
        graph_config=graph_settings,
        cnn_config=cnn_settings,
        fusion_config=fusion_settings,
        # Relative weights of the loss terms.
        classification_weight=1.0,
        reconstruction_weight=0.1,
        physics_weight=0.2,
        consistency_weight=0.15,
        # Values mirrored from the notebook-level training config.
        batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        use_gradient_checkpointing=config.use_gradient_checkpointing,
        use_mixed_precision=config.use_mixed_precision,
        use_8bit_optimizer=config.use_8bit_optimizer,
        device='cuda'
    )


multimodal_config = _build_multimodal_config()
logger.info("Multi-modal training configuration created")

In [None]:
# Build the model and count its parameters in a single pass.
model = UnifiedMultiModalSystem(multimodal_config)
total_params = 0
trainable_params = 0
for _param in model.parameters():
    _param_count = _param.numel()
    total_params += _param_count
    if _param.requires_grad:
        trainable_params += _param_count
logger.info(f"Model initialized: {total_params/1e9:.2f}B total params, {trainable_params/1e9:.2f}B trainable")

# Use the first GPU when CUDA is usable, otherwise stay on the CPU.
_device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
model = model.to(device)
logger.info(f"Model moved to device: {device}")

In [None]:
# Turn on activation checkpointing wherever the model exposes a switch for it.
if config.use_gradient_checkpointing:
    # Model-level hook, if the model provides one.
    if hasattr(model, 'gradient_checkpointing_enable'):
        model.gradient_checkpointing_enable()
        logger.info("Gradient checkpointing enabled")
    # Per-module flags: set whichever of the two attribute names a module defines.
    for module in model.modules():
        for flag_name in ('gradient_checkpointing', 'use_checkpoint'):
            if hasattr(module, flag_name):
                setattr(module, flag_name, True)

In [None]:
# Choose the optimizer: bitsandbytes 8-bit AdamW when requested and importable,
# otherwise the standard torch AdamW.
_use_8bit_adamw = config.use_8bit_optimizer and BITSANDBYTES_AVAILABLE
if not _use_8bit_adamw:
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )
    logger.info("Standard AdamW optimizer initialized")
else:
    optimizer = bnb.optim.AdamW8bit(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=(0.9, 0.999),
        eps=1e-8
    )
    logger.info("8-bit AdamW optimizer initialized (75% memory reduction)")

In [None]:
# Number of scheduler steps for the whole run.
# train_epoch steps the optimizer (and the scheduler) once per accumulation window
# AND once more on the final batch of an epoch when that window is only partially
# filled, i.e. ceil(len(train_loader) / accumulation) times per epoch. Flooring the
# product instead would under-count whenever the loader length is not a multiple of
# the accumulation factor, and OneCycleLR raises once it is stepped past total_steps.
_steps_per_epoch = -(-len(train_loader) // config.gradient_accumulation_steps)  # ceil division
total_steps = _steps_per_epoch * config.max_epochs
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    # With div_factor=10 the schedule starts at config.learning_rate and peaks at 10x.
    max_lr=config.learning_rate * 10,
    total_steps=total_steps,
    pct_start=0.3,
    anneal_strategy='cos',
    div_factor=10.0,
    final_div_factor=100.0
)
logger.info(f"OneCycleLR scheduler initialized: {total_steps} total steps")

In [None]:
# Loss scaler for mixed-precision training; stays None when AMP is switched off.
scaler = None
if config.use_mixed_precision:
    scaler = GradScaler()
if scaler:
    logger.info("Mixed precision scaler initialized")

In [None]:
# Start a Weights & Biases run when tracking is requested and the package imported.
if config.use_wandb and WANDB_AVAILABLE:
    # Config attributes copied verbatim into the run's config, in this order.
    _logged_attrs = (
        'num_gpus',
        'gpu_type',
        'total_vram_gb',
        'max_epochs',
        'batch_size',
        'gradient_accumulation_steps',
        'effective_batch_size',
        'learning_rate',
        'weight_decay',
        'use_mixed_precision',
        'use_gradient_checkpointing',
        'use_8bit_optimizer',
        'use_cpu_offloading',
    )
    _run_config = {attr: getattr(config, attr) for attr in _logged_attrs}
    _run_config['total_params'] = total_params
    _run_config['trainable_params'] = trainable_params
    wandb.init(
        project=config.wandb_project,
        entity=config.wandb_entity,
        config=_run_config,
        name=f"astrobio_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    )
    logger.info("Weights & Biases initialized")

In [None]:
def _checkpoint_epoch_number(path):
    """Return the integer N from a ``checkpoint_epoch_<N>.pt`` path, or None if absent."""
    suffix = path.stem.rsplit('_', 1)[-1]
    return int(suffix) if suffix.isdigit() else None


def _prune_old_checkpoints(keep=5):
    """Delete all but the ``keep`` highest-numbered epoch checkpoints in CHECKPOINT_DIR."""
    numbered = []
    for path in CHECKPOINT_DIR.glob("checkpoint_epoch_*.pt"):
        epoch_number = _checkpoint_epoch_number(path)
        if epoch_number is not None:
            numbered.append((epoch_number, path))
    # Sort numerically: a plain string sort places "epoch_100" before "epoch_20",
    # which would delete the newest checkpoints instead of the oldest.
    numbered.sort(key=lambda pair: pair[0])
    for _, old_ckpt in numbered[:-keep]:
        old_ckpt.unlink()
        logger.info(f"Removed old checkpoint: {old_ckpt}")


def save_checkpoint(epoch, model, optimizer, scheduler, scaler, metrics, is_best=False):
    """Write a training checkpoint for ``epoch`` and prune older ones.

    The checkpoint holds the model / optimizer state dicts, the scheduler and
    scaler state dicts (``None`` when the object is falsy), ``metrics``, the
    notebook configuration and an ISO timestamp. It is written to
    ``CHECKPOINT_DIR/checkpoint_epoch_<epoch>.pt`` and, when ``is_best`` is true,
    also to ``CHECKPOINT_DIR/best_model.pt``. Afterwards only the five
    highest-numbered epoch checkpoints are kept; ``best_model.pt`` is never pruned.

    Args:
        epoch: Epoch number used in the file name and stored in the checkpoint.
        model, optimizer: Objects exposing ``state_dict()``.
        scheduler, scaler: Optional objects exposing ``state_dict()``.
        metrics: Arbitrary metrics payload stored as-is.
        is_best: Also save a copy as the best model.
    """
    import dataclasses

    checkpoint_path = CHECKPOINT_DIR / f"checkpoint_epoch_{epoch}.pt"
    best_path = CHECKPOINT_DIR / "best_model.pt"

    # Store the configuration as a plain dict. A pickled dataclass instance defined in
    # the notebook cannot be read back by torch.load with weights_only=True and ties
    # the checkpoint to this notebook's class definition.
    if dataclasses.is_dataclass(config) and not isinstance(config, type):
        config_payload = dataclasses.asdict(config)
    else:
        config_payload = config

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'scaler_state_dict': scaler.state_dict() if scaler else None,
        'metrics': metrics,
        'config': config_payload,
        'timestamp': datetime.now().isoformat()
    }
    torch.save(checkpoint, checkpoint_path)
    logger.info(f"Checkpoint saved: {checkpoint_path}")
    if is_best:
        torch.save(checkpoint, best_path)
        logger.info(f"Best model saved: {best_path}")
    _prune_old_checkpoints(keep=5)

In [None]:
def train_epoch(epoch, model, train_loader, optimizer, scheduler, scaler, device, config):
    """Run one training epoch with gradient accumulation and optional AMP.

    For every batch the members that expose ``.to`` are moved to ``device``, the
    model is run and ``compute_multimodal_loss`` produces the total loss and a dict
    of loss components. Gradients are accumulated over
    ``config.gradient_accumulation_steps`` batches (and flushed on the last batch),
    clipped to ``config.max_grad_norm``, and applied. Progress is logged every
    ``config.log_every_n_steps`` batches, and a checkpoint is written whenever
    ``config.checkpoint_every_hours`` have passed since the epoch started or since
    the previous time-based checkpoint in this epoch.

    Args:
        epoch: Epoch number, used for logging and checkpoint naming.
        model: Model called as ``model(batch)``; must expose ``model.config``.
        train_loader: Sized iterable of dict batches.
        optimizer: Optimizer stepped once per accumulation window.
        scheduler: Optional LR scheduler stepped after each applied optimizer step.
        scaler: Optional GradScaler; AMP is used only if this is not None and
            ``config.use_mixed_precision`` is true.
        device: Target device for batch members.
        config: Training configuration object.

    Returns:
        ``(average_total_loss, epoch_metrics)`` where ``epoch_metrics`` maps
        'classification', 'llm', 'graph_vae' and 'total' to per-batch averages
        stored as plain Python floats.
    """
    model.train()
    epoch_loss = 0.0
    epoch_metrics = {'classification': 0.0, 'llm': 0.0, 'graph_vae': 0.0, 'total': 0.0}
    num_batches = len(train_loader)
    accumulation_steps = config.gradient_accumulation_steps
    use_amp = config.use_mixed_precision and scaler is not None
    optimizer.zero_grad()
    start_time = time.time()
    last_checkpoint_time = start_time
    for batch_idx, batch in enumerate(train_loader):
        # Tensors and any other batch member exposing .to() (e.g. graph batches)
        # are moved; lists of strings / metadata are left untouched.
        for key in batch:
            if hasattr(batch[key], 'to'):
                batch[key] = batch[key].to(device)
        if use_amp:
            with autocast():
                outputs = model(batch)
                total_loss, loss_dict = compute_multimodal_loss(outputs, batch, model.config)
        else:
            outputs = model(batch)
            total_loss, loss_dict = compute_multimodal_loss(outputs, batch, model.config)
        # Scale so the accumulated gradient corresponds to the mean over the window.
        loss = total_loss / accumulation_steps
        if use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()
        window_complete = (batch_idx + 1) % accumulation_steps == 0
        last_batch = (batch_idx + 1) == num_batches
        if window_complete or last_batch:
            optimizer_stepped = True
            if use_amp:
                # Gradients must be unscaled before clipping by norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                scale_before = scaler.get_scale()
                scaler.step(optimizer)
                scaler.update()
                # GradScaler lowers its scale exactly when it skipped the step because
                # of inf/NaN gradients; in that case the LR schedule must not advance.
                optimizer_stepped = scaler.get_scale() >= scale_before
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                optimizer.step()
            optimizer.zero_grad()
            if scheduler is not None and optimizer_stepped:
                scheduler.step()
        epoch_loss += total_loss.item()
        for key in loss_dict:
            if key in epoch_metrics:
                component = loss_dict[key]
                # Accumulate plain floats: keeping tensors here would retain device
                # memory (and possibly the autograd graph) and would make the metrics
                # non-JSON-serialisable. Assumes each component is a scalar.
                if isinstance(component, torch.Tensor):
                    component = component.detach().item()
                epoch_metrics[key] += float(component)
        if (batch_idx + 1) % config.log_every_n_steps == 0:
            avg_loss = epoch_loss / (batch_idx + 1)
            lr = optimizer.param_groups[0]['lr']
            logger.info(f"Epoch {epoch} [{batch_idx+1}/{num_batches}] Loss: {avg_loss:.4f} LR: {lr:.2e}")
            if config.use_wandb and WANDB_AVAILABLE:
                wandb.log({
                    'train/loss': avg_loss,
                    'train/lr': lr,
                    'train/epoch': epoch,
                    'train/step': epoch * num_batches + batch_idx
                })
        current_time = time.time()
        if (current_time - last_checkpoint_time) >= (config.checkpoint_every_hours * 3600):
            save_checkpoint(epoch, model, optimizer, scheduler, scaler, {'loss': epoch_loss / (batch_idx + 1)}, is_best=False)
            last_checkpoint_time = current_time
        gpu_monitor.log_stats()
    epoch_time = time.time() - start_time
    # Guard against an empty loader so the averages below cannot divide by zero.
    divisor = max(num_batches, 1)
    avg_epoch_loss = epoch_loss / divisor
    for key in epoch_metrics:
        epoch_metrics[key] /= divisor
    logger.info(f"Epoch {epoch} completed in {epoch_time:.2f}s - Avg Loss: {avg_epoch_loss:.4f}")
    return avg_epoch_loss, epoch_metrics

In [None]:
def validate_epoch(epoch, model, val_loader, device, config):
    """Evaluate the model on ``val_loader`` without gradient tracking.

    Each batch is moved to ``device``, run through the model (under autocast when
    ``config.use_mixed_precision`` is true) and scored with
    ``compute_multimodal_loss``. Accuracy is computed from ``outputs['logits']``
    against ``batch['habitability_label']`` when both are present. Results are
    logged and, if enabled, sent to Weights & Biases.

    Args:
        epoch: Epoch number used in log messages.
        model: Model called as ``model(batch)``; must expose ``model.config``.
        val_loader: Sized iterable of dict batches.
        device: Target device for batch members.
        config: Training configuration object.

    Returns:
        ``(average_loss, accuracy, val_metrics)``; ``accuracy`` is 0.0 when no
        labelled predictions were seen, and ``val_metrics`` holds per-batch
        averages as plain Python floats.
    """
    model.eval()
    val_loss = 0.0
    val_metrics = {'classification': 0.0, 'llm': 0.0, 'graph_vae': 0.0, 'total': 0.0}
    num_batches = len(val_loader)
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader):
            # Move every batch member that supports .to(); leave the rest as-is.
            for key in batch:
                if hasattr(batch[key], 'to'):
                    batch[key] = batch[key].to(device)
            if config.use_mixed_precision:
                with autocast():
                    outputs = model(batch)
                    total_loss, loss_dict = compute_multimodal_loss(outputs, batch, model.config)
            else:
                outputs = model(batch)
                total_loss, loss_dict = compute_multimodal_loss(outputs, batch, model.config)
            val_loss += total_loss.item()
            for key in loss_dict:
                if key in val_metrics:
                    component = loss_dict[key]
                    # Store plain floats so the metrics can be serialised to JSON and
                    # do not pin device memory. Assumes each component is a scalar.
                    if isinstance(component, torch.Tensor):
                        component = component.detach().item()
                    val_metrics[key] += float(component)
            if 'logits' in outputs and 'habitability_label' in batch:
                predictions = torch.argmax(outputs['logits'], dim=1)
                correct += (predictions == batch['habitability_label']).sum().item()
                total += batch['habitability_label'].size(0)
    # Guard against an empty loader so the averages below cannot divide by zero.
    divisor = max(num_batches, 1)
    avg_val_loss = val_loss / divisor
    for key in val_metrics:
        val_metrics[key] /= divisor
    accuracy = correct / total if total > 0 else 0.0
    logger.info(f"Validation Epoch {epoch} - Loss: {avg_val_loss:.4f} Accuracy: {accuracy:.4f}")
    if config.use_wandb and WANDB_AVAILABLE:
        wandb.log({
            'val/loss': avg_val_loss,
            'val/accuracy': accuracy,
            'val/epoch': epoch
        })
    return avg_val_loss, accuracy, val_metrics

In [None]:
# Mutable training state shared with the epoch loop and the reporting cells.
best_val_loss = float('inf')
best_val_accuracy = 0.0
patience_counter = 0
training_start_time = time.time()
training_history = []

# Banner summarising the run configuration, emitted line by line.
for _banner_line in (
    "="*80,
    "STARTING TRAINING",
    "="*80,
    f"Configuration:",
    f"  GPUs: {config.num_gpus} x {config.gpu_type}",
    f"  Total VRAM: {config.total_vram_gb} GB",
    f"  Max Epochs: {config.max_epochs}",
    f"  Batch Size: {config.batch_size}",
    f"  Gradient Accumulation: {config.gradient_accumulation_steps}",
    f"  Effective Batch Size: {config.effective_batch_size}",
    f"  Learning Rate: {config.learning_rate}",
    f"  Mixed Precision: {config.use_mixed_precision}",
    f"  Gradient Checkpointing: {config.use_gradient_checkpointing}",
    f"  8-bit Optimizer: {config.use_8bit_optimizer}",
    f"  CPU Offloading: {config.use_cpu_offloading}",
    f"  Model Parameters: {total_params/1e9:.2f}B",
    "="*80,
):
    logger.info(_banner_line)

In [None]:
# Main epoch loop: train, validate, record history, checkpoint, then test the
# stopping criteria (patience, target accuracy, wall-clock budget) in that order.
for epoch in range(1, config.max_epochs + 1):
    logger.info(f"\n{'='*80}")
    logger.info(f"EPOCH {epoch}/{config.max_epochs}")
    logger.info(f"{'='*80}")

    train_loss, train_metrics = train_epoch(
        epoch, model, train_loader, optimizer, scheduler, scaler, device, config
    )
    val_loss, val_accuracy, val_metrics = validate_epoch(
        epoch, model, val_loader, device, config
    )

    epoch_history = {
        'epoch': epoch,
        'train_loss': train_loss,
        'train_metrics': train_metrics,
        'val_loss': val_loss,
        'val_accuracy': val_accuracy,
        'val_metrics': val_metrics,
        'timestamp': datetime.now().isoformat()
    }
    training_history.append(epoch_history)

    # "Best" is defined by validation loss; accuracy is tracked separately.
    is_best = val_loss < best_val_loss
    if is_best:
        best_val_loss = val_loss
        patience_counter = 0
        logger.info(f"New best validation loss: {best_val_loss:.4f}")
    else:
        patience_counter += 1
        logger.info(f"No improvement. Patience: {patience_counter}/{config.early_stopping_patience}")

    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        logger.info(f"New best validation accuracy: {best_val_accuracy:.4f}")

    periodic_save_due = epoch % config.save_every_n_epochs == 0
    if periodic_save_due or is_best:
        save_checkpoint(epoch, model, optimizer, scheduler, scaler, epoch_history, is_best=is_best)

    # Stopping criterion 1: no validation-loss improvement for too many epochs.
    if patience_counter >= config.early_stopping_patience:
        logger.info(f"Early stopping triggered after {epoch} epochs")
        break

    # Stopping criterion 2: accuracy target met; this model is saved as the best one.
    if val_accuracy >= config.target_accuracy:
        logger.info(f"Target accuracy {config.target_accuracy:.2%} reached!")
        save_checkpoint(epoch, model, optimizer, scheduler, scaler, epoch_history, is_best=True)
        break

    # Stopping criterion 3: wall-clock budget exhausted.
    elapsed_hours = (time.time() - training_start_time) / 3600
    if elapsed_hours >= config.max_training_hours:
        logger.info(f"Maximum training time {config.max_training_hours} hours reached")
        save_checkpoint(epoch, model, optimizer, scheduler, scaler, epoch_history, is_best=False)
        break

    # End-of-epoch housekeeping (skipped on the epoch that breaks out of the loop).
    gpu_monitor.log_stats(force=True)
    torch.cuda.empty_cache()
    gc.collect()

In [None]:
# Wall-clock duration since the training state was initialised, plus a summary banner.
total_training_time = time.time() - training_start_time
for _summary_line in (
    "="*80,
    "TRAINING COMPLETED",
    "="*80,
    f"Total Training Time: {total_training_time/3600:.2f} hours",
    f"Best Validation Loss: {best_val_loss:.4f}",
    f"Best Validation Accuracy: {best_val_accuracy:.4f}",
    f"Total Epochs: {len(training_history)}",
    "="*80,
):
    logger.info(_summary_line)

In [None]:
def _history_json_default(value):
    """``json.dump`` fallback for values the encoder cannot serialise natively.

    Objects exposing ``tolist()`` or ``item()`` (as torch tensors and NumPy
    scalars/arrays do) are converted to plain Python numbers or lists; anything
    else is stored as its string form rather than aborting the dump.
    """
    if hasattr(value, 'tolist'):
        return value.tolist()
    if hasattr(value, 'item'):
        return value.item()
    return str(value)


# Persist the per-epoch history next to the logs, in a timestamped file.
history_path = LOG_DIR / f"training_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
with open(history_path, 'w') as f:
    # The default hook keeps one non-native metric value from failing the dump
    # (and leaving a truncated file) after a long training run.
    json.dump(training_history, f, indent=2, default=_history_json_default)
logger.info(f"Training history saved: {history_path}")

In [None]:
# Push the final summary numbers to Weights & Biases and close the run.
if config.use_wandb and WANDB_AVAILABLE:
    _final_summary = {
        'final/best_val_loss': best_val_loss,
        'final/best_val_accuracy': best_val_accuracy,
        'final/total_epochs': len(training_history),
        'final/total_training_hours': total_training_time / 3600
    }
    wandb.log(_final_summary)
    wandb.finish()
    logger.info("Weights & Biases run finished")

In [None]:
# Evaluate the best checkpoint on the held-out test split.
logger.info("Testing best model on test set...")
# map_location places the stored tensors on the evaluation device regardless of
# where they were saved from.
# weights_only=False: the checkpoint was written by this notebook and may contain
# non-tensor Python objects (the training config), which the weights_only=True
# default of recent PyTorch releases refuses to unpickle. This uses pickle, so only
# load checkpoint files you produced yourself.
best_checkpoint = torch.load(
    CHECKPOINT_DIR / "best_model.pt",
    map_location=device,
    weights_only=False
)
model.load_state_dict(best_checkpoint['model_state_dict'])
model.eval()
test_loss = 0.0
test_correct = 0
test_total = 0
with torch.no_grad():
    for batch_idx, batch in enumerate(test_loader):
        # Move every batch member that supports .to(); leave the rest as-is.
        for key in batch:
            if hasattr(batch[key], 'to'):
                batch[key] = batch[key].to(device)
        outputs = model(batch)
        total_loss, loss_dict = compute_multimodal_loss(outputs, batch, model.config)
        test_loss += total_loss.item()
        if 'logits' in outputs and 'habitability_label' in batch:
            predictions = torch.argmax(outputs['logits'], dim=1)
            test_correct += (predictions == batch['habitability_label']).sum().item()
            test_total += batch['habitability_label'].size(0)
# Average over batches; max() keeps an empty loader from dividing by zero.
test_loss /= max(len(test_loader), 1)
test_accuracy = test_correct / test_total if test_total > 0 else 0.0
logger.info("="*80)
logger.info("TEST SET RESULTS")
logger.info("="*80)
logger.info(f"Test Loss: {test_loss:.4f}")
logger.info(f"Test Accuracy: {test_accuracy:.4f}")
logger.info("="*80)

In [None]:
# One last memory snapshot per GPU, taken directly (bypassing the rate limiter).
final_stats = gpu_monitor.get_gpu_stats()
logger.info("Final GPU Statistics:")
for gpu_id in final_stats:
    stats = final_stats[gpu_id]
    logger.info(f"{gpu_id}: {stats['allocated_gb']:.2f}GB / {stats['total_gb']:.2f}GB ({stats['utilization_pct']:.1f}%)")

In [None]:
# Final marker: reached only if every preceding cell ran without raising.
logger.info("Training notebook execution completed successfully!")


