# LEAP Pangeo JupyterHub Startup Cell

This cell performs environment checks, package installations, and setup for the LEAP hackathon.

**Best Practices:**
- ⚠️ Do NOT store large data in your home directory (limited space)
- ✅ Use `/home/jovyan/leap-scratch/<your-name>/` for outputs and temporary files
- ✅ Read data directly from cloud storage (OSN) using s3fs/fsspec

In [None]:
# ============================================================================
# LEAP Pangeo JupyterHub Startup Cell
# ============================================================================

import sys
import subprocess
import warnings

print("=" * 80)
print("LEAP PANGEO JUPYTERHUB ENVIRONMENT CHECK")
print("=" * 80)

# ----------------------------------------------------------------------------
# 1. Check Python Version
# ----------------------------------------------------------------------------
# Report the interpreter backing this kernel; the pip installs further down
# run through this same executable (`sys.executable -m pip`).
print(f"\n📍 Python Version: {sys.version}")
print(f"📍 Python Executable: {sys.executable}")

# ----------------------------------------------------------------------------
# 2. Check GPU Availability
# ----------------------------------------------------------------------------
print("\n" + "=" * 80)
print("GPU AVAILABILITY CHECK")
print("=" * 80)

# Check with PyTorch (optional; the hackathon stack itself is JAX-based).
try:
    import torch
    print(f"\n🔥 PyTorch version: {torch.__version__}")
    if torch.cuda.is_available():
        print("✅ CUDA is available!")
        print(f"   - GPU Device: {torch.cuda.get_device_name(0)}")
        print(f"   - Number of GPUs: {torch.cuda.device_count()}")
        print(f"   - CUDA Version: {torch.version.cuda}")
    else:
        print("⚠️  No CUDA GPU detected by PyTorch (running on CPU)")
except ImportError:
    print("⚠️  PyTorch not installed (will check JAX for GPU support)")
except Exception as e:
    # A broken driver/CUDA install can surface as OSError/RuntimeError from the
    # import or the device queries; report it (as the JAX check below does)
    # instead of aborting the whole startup cell.
    print(f"⚠️  PyTorch GPU check error: {e}")

# Check with JAX
try:
    import jax
    print(f"\n🔷 JAX version: {jax.__version__}")
    devices = jax.devices()
    print(f"   - Available devices: {devices}")
    # Classify by the device's platform attribute instead of substring-matching
    # its repr, whose format has changed between JAX releases.
    gpu_devices = [d for d in devices if d.platform in ("gpu", "cuda", "rocm")]
    if gpu_devices:
        print(f"✅ JAX GPU detected! ({len(gpu_devices)} device(s))")
    else:
        print(f"⚠️  JAX running on: {devices[0].device_kind}")
except ImportError:
    print("⚠️  JAX not yet installed")
except Exception as e:
    print(f"⚠️  JAX device check error: {e}")

# ----------------------------------------------------------------------------
# 3. Install/Verify Required Packages
# ----------------------------------------------------------------------------
import importlib.metadata

print("\n" + "=" * 80)
print("PACKAGE INSTALLATION CHECK")
print("=" * 80)

# pip *distribution* names (what `pip install` expects), not import names --
# e.g. 'orbax-checkpoint' is imported as `orbax.checkpoint`.
required_packages = [
    'jax',
    'jaxlib',
    'flax',
    'optax',
    'orbax-checkpoint',
    'xarray',
    's3fs',
    'fsspec',
    'gcsfs',
    'numpy',
    'pandas',
    'matplotlib',
    'cartopy',
    'huggingface_hub',
    'datasets',  # For Hugging Face datasets
]


def _is_installed(dist_name):
    """Return True if distribution *dist_name* is installed for this kernel's Python."""
    try:
        importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return False
    return True


print("\n📦 Installing/verifying required packages...\n")
newly_installed = []
failed_packages = []
for package in required_packages:
    # `pip install` of an already-present package changes nothing, so skip the
    # slow pip subprocess when the distribution is already there.
    if _is_installed(package):
        print(f"{package} already installed ✅")
        continue
    print(f"Installing {package}...", end=" ")
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", package],
        capture_output=True,
        text=True,
    )
    if result.returncode == 0:
        newly_installed.append(package)
        print("✅")
    else:
        # pip exits 0 for "already satisfied", so a non-zero code is a real
        # failure: surface pip's last error line instead of discarding stderr.
        error_lines = result.stderr.strip().splitlines()
        detail = error_lines[-1] if error_lines else "no error output"
        print(f"❌ pip failed: {detail}")
        failed_packages.append(package)

if failed_packages:
    print(f"\n⚠️  Package installation finished with failures: {', '.join(failed_packages)}")
else:
    print("\n✅ Package installation complete!")
if newly_installed:
    # Modules already imported in this kernel (e.g. jax in the GPU check) are
    # not reloaded by pip.
    print("💡 Newly installed packages may require a kernel restart to import cleanly.")

# ----------------------------------------------------------------------------
# 4. Import Packages Safely
# ----------------------------------------------------------------------------
import importlib.metadata

print("\n" + "=" * 80)
print("PACKAGE IMPORT VERIFICATION")
print("=" * 80 + "\n")

# Package label -> one-line status ("✅ <version>" or "❌ Failed: <error>").
import_results = {}


def _record_ok(label, module):
    """Record a successful import of *module* under *label*, including its version.

    ``__version__`` is not guaranteed on every module, so fall back to the
    installed distribution named *label*, then to a placeholder, rather than
    letting an AttributeError abort the cell.
    """
    version = getattr(module, "__version__", None)
    if version is None:
        try:
            version = importlib.metadata.version(label)
        except importlib.metadata.PackageNotFoundError:
            version = "version unknown"
    import_results[label] = f"✅ {version}"


def _record_failure(label, exc):
    """Record a failed import of *label* together with the exception message."""
    import_results[label] = f"❌ Failed: {exc}"


# Imports are spelled out (not looped over importlib) so the conventional
# aliases -- jnp, nn, np, pd, xr, plt -- stay bound for later cells.
# `except Exception` instead of ImportError: native-library problems can
# surface as OSError/RuntimeError at import time, and this diagnostic should
# report them rather than crash.

# JAX ecosystem
try:
    import jax
    import jax.numpy as jnp
    _record_ok('jax', jax)
except Exception as e:
    _record_failure('jax', e)

try:
    import jaxlib
    _record_ok('jaxlib', jaxlib)
except Exception as e:
    _record_failure('jaxlib', e)

try:
    import flax
    from flax import linen as nn
    _record_ok('flax', flax)
except Exception as e:
    _record_failure('flax', e)

try:
    import optax
    _record_ok('optax', optax)
except Exception as e:
    _record_failure('optax', e)

try:
    import orbax.checkpoint
    _record_ok('orbax-checkpoint', orbax.checkpoint)
except Exception as e:
    _record_failure('orbax-checkpoint', e)

# Data science packages
try:
    import numpy as np
    _record_ok('numpy', np)
except Exception as e:
    _record_failure('numpy', e)

try:
    import pandas as pd
    _record_ok('pandas', pd)
except Exception as e:
    _record_failure('pandas', e)

try:
    import xarray as xr
    _record_ok('xarray', xr)
except Exception as e:
    _record_failure('xarray', e)

# Cloud storage packages
try:
    import s3fs
    _record_ok('s3fs', s3fs)
except Exception as e:
    _record_failure('s3fs', e)

try:
    import fsspec
    _record_ok('fsspec', fsspec)
except Exception as e:
    _record_failure('fsspec', e)

try:
    import gcsfs
    _record_ok('gcsfs', gcsfs)
except Exception as e:
    _record_failure('gcsfs', e)

# Visualization packages
try:
    import matplotlib
    import matplotlib.pyplot as plt
    _record_ok('matplotlib', matplotlib)
except Exception as e:
    _record_failure('matplotlib', e)

try:
    import cartopy
    _record_ok('cartopy', cartopy)
except Exception as e:
    _record_failure('cartopy', e)

# ML Hub
try:
    import huggingface_hub
    _record_ok('huggingface_hub', huggingface_hub)
except Exception as e:
    _record_failure('huggingface_hub', e)

try:
    import datasets
    _record_ok('datasets', datasets)
except Exception as e:
    _record_failure('datasets', e)

# Print results as a dot-leader table
for package, status in import_results.items():
    print(f"{package:.<30} {status}")

# ----------------------------------------------------------------------------
# 5. OSN Configuration Constants
# ----------------------------------------------------------------------------
print("\n" + "=" * 80)
print("OSN (OPEN STORAGE NETWORK) CONFIGURATION")
print("=" * 80 + "\n")

# OSN endpoint and bucket configuration from LEAP hackathon guide.
# These module-level names are reused by later notebook cells.
OSN_ENDPOINT_URL = "https://nyu1.osn.mghpcc.org"
OSN_BUCKET = "leap-pangeo-manual"
HACKATHON_PREFIX = "hackathon-2026"

print(f"📍 OSN_ENDPOINT_URL:  {OSN_ENDPOINT_URL}")
print(f"📍 OSN_BUCKET:        {OSN_BUCKET}")
print(f"📍 HACKATHON_PREFIX:  {HACKATHON_PREFIX}")

print(f"\n💡 Full S3 path: s3://{OSN_BUCKET}/{HACKATHON_PREFIX}/")
print("💡 Available datasets: hrrr/, era5_cds/nyc/, corrdiff/")

# Example usage, rendered as one template so the snippet reads as it will be
# pasted. Doubled braces are literal braces in the printed code.
print("\n📝 Example usage for reading data from OSN:")
print(f"""```python
import s3fs
import xarray as xr

# Create S3 filesystem object (no credentials needed for public data)
fs = s3fs.S3FileSystem(anon=True, client_kwargs={{'endpoint_url': '{OSN_ENDPOINT_URL}'}})

# List available datasets
fs.ls('{OSN_BUCKET}/{HACKATHON_PREFIX}/')  # Shows: hrrr/, era5_cds/, corrdiff/

# Open a Zarr dataset lazily
ds = xr.open_zarr(fs.get_mapper('s3://{OSN_BUCKET}/{HACKATHON_PREFIX}/hrrr/your-file.zarr'))
```""")

# ----------------------------------------------------------------------------
# 6. Best Practices Reminder
# ----------------------------------------------------------------------------
print("\n" + "=" * 80)
print("⚠️  BEST PRACTICES REMINDER")
print("=" * 80 + "\n")

print("🚫 DO NOT store large data in your home directory!")
print("   Home directory has limited space (~10GB)")
print("")
print("✅ DO use the leap-scratch directory for outputs:")
print("   /home/jovyan/leap-scratch/<your-name>/")
print("")
print("✅ DO read data directly from cloud storage using s3fs/fsspec")
print("   This avoids downloading large files unnecessarily")
print("")
print("✅ DO use Dask for processing large datasets")
print("   Dask allows lazy loading and parallel processing")
print("")
print("💡 Create your scratch directory if it doesn't exist:")
# The hint is IPython `!` shell syntax (run it in a notebook cell), not plain
# bash. $JUPYTERHUB_USER is the hub login name; $USER inside a hub container
# is typically the shared container user (e.g. jovyan), which would likely put
# every participant in the same folder.
print("```ipython")
print("!mkdir -p /home/jovyan/leap-scratch/$JUPYTERHUB_USER")
print("```")

print("\n" + "=" * 80)
print("✅ STARTUP COMPLETE - Ready to hack!")
print("=" * 80)

## Quick Start Examples

### Create Your Scratch Directory
```python
import os
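# JUPYTERHUB_USER is your hub login; USER is typically the shared container user (e.g. jovyan)
user = os.environ.get('JUPYTERHUB_USER') or os.environ.get('USER', 'default')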
scratch_dir = f"/home/jovyan/leap-scratch/{user}"
os.makedirs(scratch_dir, exist_ok=True)
print(f"Your scratch directory: {scratch_dir}")
```

### Access OSN Data
```python
import s3fs
import xarray as xr

# Initialize S3 filesystem (OSN_ENDPOINT_URL, OSN_BUCKET and HACKATHON_PREFIX are defined by the startup cell)
fs = s3fs.S3FileSystem(
    anon=True,  # Anonymous access for public data
    client_kwargs={'endpoint_url': OSN_ENDPOINT_URL}
)

# List available datasets
files = fs.ls(f"{OSN_BUCKET}/{HACKATHON_PREFIX}")
print(files)
```