# Load ML Datasets from Google Drive

This notebook downloads and loads the crash prediction datasets from Google Drive for exploratory data analysis (EDA).

**Datasets Available:**
- Crash-level dataset (train, val, test)
- Segment-level dataset (train, val, test)
- Raw Texas data (crashes, weather, work zones, traffic)

**Setup Required:**
1. Install required packages (see cell below)
2. Authenticate with Google Drive (first run only; requires an OAuth client `credentials.json` in the project root)
3. Download datasets
4. Load and explore!

## 1. Setup & Installation

In [None]:
# Install required packages (uncomment if needed)
# !pip install --upgrade google-api-python-client google-auth-httplib2 google-auth-oauthlib
# !pip install pandas numpy matplotlib seaborn geopandas

In [None]:
import os
import sys
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Set plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

# Add project root to path. This assumes the notebook is launched from a
# direct subdirectory of the project (presumably a folder such as
# notebooks/) -- adjust project_root if it is run from elsewhere.
project_root = Path.cwd().parent
# Only insert once so re-running this cell does not pile up duplicate entries.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

print("✓ Imports successful")
print(f"Project root: {project_root}")

## 2. Google Drive Configuration

In [None]:
# Drive folder holding the datasets; this is the same ID used by
# upload_to_gdrive.py.
FOLDER_ID = '1xVGXbxUFHSdSawo2C9wnmABj15wPEX3A'

# Everything is downloaded under <project_root>/data/downloaded_from_gdrive.
# The directory is created up front so later cells can write into it directly.
DOWNLOAD_DIR = project_root.joinpath('data', 'downloaded_from_gdrive')
DOWNLOAD_DIR.mkdir(exist_ok=True, parents=True)

print(
    f"Download directory: {DOWNLOAD_DIR}\n"
    f"Google Drive folder: https://drive.google.com/drive/folders/{FOLDER_ID}"
)

## 3. Google Drive Authentication & Download Functions

In [None]:
# Optional Google Drive API imports. GDRIVE_AVAILABLE records whether they
# succeeded; the authentication cell checks it and leaves `service` as None
# when the packages are missing.
try:
    from google.auth.transport.requests import Request
    from google.oauth2.credentials import Credentials
    from google_auth_oauthlib.flow import InstalledAppFlow
    from googleapiclient.discovery import build
    from googleapiclient.http import MediaIoBaseDownload
    from googleapiclient.errors import HttpError
    import io

    GDRIVE_AVAILABLE = True
    print("✓ Google Drive API packages available")
except ImportError:
    GDRIVE_AVAILABLE = False
    print("⚠️  Google Drive API packages not installed")
    print("   Run: pip install --upgrade google-api-python-client google-auth-httplib2 google-auth-oauthlib")

In [None]:
# Google Drive API scopes (read-only access is all this notebook needs)
SCOPES = ['https://www.googleapis.com/auth/drive.readonly']

def authenticate_gdrive():
    """
    Authenticate with the Google Drive API and build a Drive v3 service.

    Cached credentials are read from ``<project_root>/token.json``. Expired
    credentials are refreshed when they carry a refresh token. If there is
    no usable token, or the refresh is rejected by Google, the interactive
    browser OAuth flow is run with ``<project_root>/credentials.json``.
    Newly obtained or refreshed credentials are written back to
    ``token.json``.

    Returns:
        service: Google Drive API service object, or None when a new
        interactive login is needed but ``credentials.json`` is missing.
    """
    # Local import: this module is only importable when the Google packages
    # are installed, and callers only reach this function in that case.
    from google.auth.exceptions import RefreshError

    creds = None
    token_file = project_root / 'token.json'
    credentials_file = project_root / 'credentials.json'

    # Load existing token if available
    if token_file.exists():
        creds = Credentials.from_authorized_user_file(str(token_file), SCOPES)

    # If no valid credentials, refresh or run the interactive flow
    if not creds or not creds.valid:
        refreshed = False
        if creds and creds.expired and creds.refresh_token:
            print("🔄 Refreshing expired credentials...")
            try:
                creds.refresh(Request())
                refreshed = True
            except RefreshError as error:
                # A rejected refresh token cannot be retried; fall back to a
                # fresh interactive login instead of crashing the cell.
                print(f"⚠️  Could not refresh credentials ({error}); re-authenticating...")

        if not refreshed:
            if not credentials_file.exists():
                print(f"❌ Error: {credentials_file} not found")
                print("\nPlease set up Google Drive API credentials:")
                print("1. Go to https://console.cloud.google.com/")
                print("2. Create/select project")
                print("3. Enable Google Drive API")
                print("4. Create OAuth 2.0 credentials (Desktop app)")
                print(f"5. Download as credentials.json to {project_root}")
                return None

            print("🔐 Authenticating with Google Drive...")
            print("   (Browser will open for authorization)")
            flow = InstalledAppFlow.from_client_secrets_file(
                str(credentials_file), SCOPES
            )
            creds = flow.run_local_server(port=0)

        # Save credentials for future use
        token_file.write_text(creds.to_json(), encoding='utf-8')
        print("✅ Credentials saved")

    # Build service
    return build('drive', 'v3', credentials=creds)

def list_files_in_folder(service, folder_id, verbose=True):
    """
    List all files in a Google Drive folder recursively.

    Folders themselves are not returned; their contents are. Each file's
    ``path`` is prefixed with the names of the folders above it
    (e.g. ``'crash_level/train.csv'``).

    Args:
        service: Google Drive API service object (from authenticate_gdrive).
        folder_id: ID of the Drive folder to walk.
        verbose: If True, print the total number of files found.

    Returns:
        List of file-metadata dicts with keys ``id``, ``name``, ``mimeType``,
        ``path`` and, when Drive reports it, ``size``. On an API error the
        error is printed and an empty list is returned for that folder.
    """
    query = f"'{folder_id}' in parents and trashed=false"
    try:
        files = []
        page_token = None

        # Page through the folder's direct children.
        while True:
            results = service.files().list(
                q=query,
                spaces='drive',
                fields='nextPageToken, files(id, name, mimeType, size)',
                pageSize=1000,  # API maximum (default 100): fewer round trips
                pageToken=page_token,
            ).execute()

            files.extend(results.get('files', []))

            page_token = results.get('nextPageToken')
            if not page_token:
                break

        # Recursively get files from subfolders
        all_files = []
        for item in files:
            if item['mimeType'] == 'application/vnd.google-apps.folder':
                subfolder_files = list_files_in_folder(service, item['id'], verbose=False)
                # Every returned entry already has a 'path'; prefix this folder's name.
                for f in subfolder_files:
                    f['path'] = f"{item['name']}/{f['path']}"
                all_files.extend(subfolder_files)
            else:
                item['path'] = item['name']
                all_files.append(item)

        if verbose:
            print(f"✓ Found {len(all_files)} files")

        return all_files

    except HttpError as error:
        print(f"❌ Error listing files: {error}")
        return []

def download_file(service, file_id, file_name, dest_path, verbose=True):
    """
    Download a file from Google Drive to ``dest_path``.

    Data is streamed into a sibling ``<name>.part`` file. That file is
    renamed onto ``dest_path`` only after the download completes. A failed
    or interrupted download therefore never leaves a truncated file at
    ``dest_path``, which the download cell would otherwise skip as
    "already exists".

    Args:
        service: Google Drive API service object.
        file_id: Drive ID of the file to fetch.
        file_name: Display name used in progress/error messages.
        dest_path: Local destination path; parent directories are created.
        verbose: If True, print progress and the final file size.

    Returns:
        True on success, False if the Drive API raised an HttpError.
    """
    dest_path = Path(dest_path)
    # Create parent directories
    dest_path.parent.mkdir(parents=True, exist_ok=True)
    part_path = dest_path.with_name(dest_path.name + '.part')

    try:
        request = service.files().get_media(fileId=file_id)

        with open(part_path, 'wb') as f:
            downloader = MediaIoBaseDownload(f, request)
            done = False
            while not done:
                status, done = downloader.next_chunk()
                if verbose and status:
                    progress = int(status.progress() * 100)
                    print(f"  Downloading {file_name}: {progress}%", end='\r')

        # Publish the file only once it is complete.
        part_path.replace(dest_path)

    except HttpError as error:
        print(f"  ❌ Error downloading {file_name}: {error}")
        return False

    finally:
        # Remove any leftover partial file (API error or interruption).
        if part_path.exists():
            part_path.unlink()

    if verbose:
        file_size = dest_path.stat().st_size / 1024 / 1024
        print(f"  ✓ Downloaded {file_name} ({file_size:.1f} MB)")

    return True

print("✓ Functions defined")

## 4. Download Datasets from Google Drive

In [None]:
# Authenticate. `service` stays None when the Drive packages are missing
# or authentication could not complete; later cells check it.
service = None
if GDRIVE_AVAILABLE:
    service = authenticate_gdrive()

    if service:
        print("\n✅ Successfully authenticated with Google Drive")
    else:
        print("\n⚠️  Authentication failed - please check credentials.json")
else:
    print("\n⚠️  Google Drive API not available")

In [None]:
# List available files. `files` defaults to an empty list so it is always
# defined, even when listing is skipped.
files = []
if service:
    print("\n" + "="*70)
    print("📁 Listing files in Google Drive folder...")
    print("="*70)

    files = list_files_in_folder(service, FOLDER_ID)

    # Show file structure
    print("\nAvailable files:")
    for f in sorted(files, key=lambda x: x['path']):
        # Entries without a 'size' field (e.g. Google-native documents) show as 0 MB.
        size_mb = int(f.get('size', 0)) / 1024 / 1024
        print(f"  {f['path']:<60} ({size_mb:>6.1f} MB)")
else:
    print("\n⚠️  Skipping file listing (authentication required)")

In [None]:
# Download datasets. Files that already exist locally are skipped, so the
# cell can be re-run to fetch only what is missing.
if service:
    print("\n" + "="*70)
    print("📥 Downloading datasets...")
    print("="*70)

    downloaded = 0
    failed = []
    for file_info in files:
        # Skip folders (defensive: list_files_in_folder already flattens them)
        if file_info['mimeType'] == 'application/vnd.google-apps.folder':
            continue

        # Local path mirrors the Drive folder structure
        local_path = DOWNLOAD_DIR / file_info['path']

        # Skip if already exists
        if local_path.exists():
            print(f"  ⏭️  Skipping {file_info['path']} (already exists)")
            continue

        # Download
        if download_file(service, file_info['id'], file_info['name'], local_path):
            downloaded += 1
        else:
            failed.append(file_info['path'])

    # rglob('*') also yields subdirectories; count regular files only.
    local_file_count = sum(1 for p in DOWNLOAD_DIR.rglob('*') if p.is_file())
    print(f"\n✅ Downloaded {downloaded} new files")
    print(f"   Total files in {DOWNLOAD_DIR}: {local_file_count}")
    if failed:
        print(f"⚠️  {len(failed)} file(s) failed to download:")
        for failed_path in failed:
            print(f"   - {failed_path}")
else:
    print("\n⚠️  Skipping download (authentication required)")
    print("\nAlternative: Manually download from:")
    print(f"  https://drive.google.com/drive/folders/{FOLDER_ID}")
    print(f"  Save to: {DOWNLOAD_DIR}")

## 5. Load Datasets into Pandas

In [None]:
# Define dataset paths
CRASH_LEVEL_DIR = DOWNLOAD_DIR / 'crash_level'
SEGMENT_LEVEL_DIR = DOWNLOAD_DIR / 'segment_level'


def _load_split(directory, split, label):
    """Read ``<directory>/<split>.csv`` into a DataFrame, or return None if absent.

    Prints a one-line shape summary prefixed with *label* on success, or a
    warning naming the directory that was searched.
    """
    csv_path = directory / f'{split}.csv'
    if not csv_path.exists():
        print(f"⚠️  {csv_path.name} not found in {directory}")
        return None
    df = pd.read_csv(csv_path)
    print(f"✓ Loaded {label}: {df.shape[0]:,} rows × {df.shape[1]} columns")
    return df


# Load crash-level datasets
print("\n" + "="*70)
print("📊 Loading CRASH-LEVEL datasets...")
print("="*70)

crash_train = _load_split(CRASH_LEVEL_DIR, 'train', 'crash_train')
crash_val = _load_split(CRASH_LEVEL_DIR, 'val', 'crash_val')
crash_test = _load_split(CRASH_LEVEL_DIR, 'test', 'crash_test')

In [None]:
def _load_split(directory, split, label):
    """Read ``<directory>/<split>.csv`` into a DataFrame, or return None if absent.

    Prints a one-line shape summary prefixed with *label* on success, or a
    warning naming the directory that was searched.
    """
    csv_path = directory / f'{split}.csv'
    if not csv_path.exists():
        print(f"⚠️  {csv_path.name} not found in {directory}")
        return None
    df = pd.read_csv(csv_path)
    print(f"✓ Loaded {label}: {df.shape[0]:,} rows × {df.shape[1]} columns")
    return df


# Load segment-level datasets
print("\n" + "="*70)
print("📊 Loading SEGMENT-LEVEL datasets...")
print("="*70)

segment_train = _load_split(SEGMENT_LEVEL_DIR, 'train', 'segment_train')
segment_val = _load_split(SEGMENT_LEVEL_DIR, 'val', 'segment_val')
segment_test = _load_split(SEGMENT_LEVEL_DIR, 'test', 'segment_test')

## 6. Quick Data Overview

In [None]:
# Display crash-level training data overview
if crash_train is not None:
    print("\n" + "="*70)
    print("📋 CRASH-LEVEL TRAINING DATA OVERVIEW")
    print("="*70)

    print(f"\nShape: {crash_train.shape[0]:,} rows × {crash_train.shape[1]} columns")

    # Target variable distribution
    if 'high_severity' in crash_train.columns:
        print("\nTarget Variable (high_severity):")
        print(crash_train['high_severity'].value_counts())
        print(f"  High severity rate: {crash_train['high_severity'].mean()*100:.1f}%")

    # Temporal split. NOTE: this adds a derived 'year' column to crash_train
    # in place; the target-distribution plot later in the notebook uses it.
    if 'Start_Time' in crash_train.columns:
        crash_train['year'] = pd.to_datetime(crash_train['Start_Time']).dt.year
        print("\nTemporal Distribution:")
        print(crash_train['year'].value_counts().sort_index())

    # Sample data (`display` is the IPython/Jupyter rich-output helper)
    print("\nFirst 5 rows:")
    display(crash_train.head())

    print("\nData Types:")
    print(crash_train.dtypes.value_counts())

    print("\nMissing Values (top 10):")
    missing = crash_train.isnull().sum().sort_values(ascending=False).head(10)
    missing_pct = (missing / len(crash_train) * 100).round(1)
    print(pd.DataFrame({'Missing': missing, 'Percent': missing_pct}))

## 7. Feature Categories

In [None]:
# Categorize features by source/type. Each column is assigned to exactly
# ONE category, so per-category statistics (e.g. the completeness plot)
# never count a column twice.
if crash_train is not None:
    cols = crash_train.columns.tolist()

    # Display order; 'Other' collects every column no rule matched.
    category_order = [
        'Target', 'Temporal', 'Location', 'Weather', 'Road (OSMnx)',
        'Road (HPMS)', 'Traffic', 'Work Zones', 'Lighting', 'Other',
    ]

    # Keyword heuristics (case-insensitive substring match), tried in this
    # order after the target and data-source-prefix checks below.
    keyword_rules = {
        'Temporal': ['time', 'hour', 'day', 'month', 'year', 'date'],
        'Location': ['lat', 'lng', 'lon', 'city', 'county', 'state', 'street', 'zipcode'],
        'Weather': ['weather', 'temp', 'wind', 'precip', 'humidity', 'pressure', 'visibility'],
        'Road (OSMnx)': ['highway', 'lanes', 'bridge', 'tunnel', 'oneway'],
        'Traffic': ['aadt', 'traffic'],
        'Lighting': ['light'],
    }

    def _categorize_feature(col):
        """Return the single category for *col*: target first, then
        data-source prefixes, then the first matching keyword rule."""
        lowered = col.lower()
        if 'severity' in lowered:
            return 'Target'
        # Source prefixes take priority over keywords so that, e.g., an
        # 'hpms_*lanes' or 'hpms_aadt' column is attributed to HPMS only.
        if lowered.startswith('hpms_'):
            return 'Road (HPMS)'
        if lowered.startswith('osmnx_'):
            return 'Road (OSMnx)'
        if 'wz_' in lowered or 'work_zone' in lowered:
            return 'Work Zones'
        for category, keywords in keyword_rules.items():
            if any(k in lowered for k in keywords):
                return category
        return 'Other'

    feature_categories = {category: [] for category in category_order}
    for c in cols:
        feature_categories[_categorize_feature(c)].append(c)

    print("\n" + "="*70)
    print("📂 FEATURE CATEGORIES")
    print("="*70)

    for category, features in feature_categories.items():
        if features:
            print(f"\n{category} ({len(features)} features):")
            for f in features[:10]:  # Show first 10
                print(f"  - {f}")
            if len(features) > 10:
                print(f"  ... and {len(features) - 10} more")

## 8. Starter EDA Code

Below are some starter code snippets for exploratory data analysis.

In [None]:
# Target variable distribution
if crash_train is not None and 'high_severity' in crash_train.columns:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Sort by class value, not frequency. value_counts() orders by count,
    # so positional tick labels would be swapped whenever class 1 is the
    # majority.
    severity_counts = crash_train['high_severity'].value_counts().sort_index()
    class_labels = {0: 'No (0)', 1: 'Yes (1)'}

    # Count plot
    severity_counts.plot(kind='bar', ax=axes[0])
    axes[0].set_title('Target Variable Distribution (high_severity)', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('High Severity')
    axes[0].set_ylabel('Count')
    axes[0].set_xticklabels([class_labels.get(k, str(k)) for k in severity_counts.index], rotation=0)

    # Add counts and percentages above each bar. The offset scales with the
    # tallest bar; a fixed +1000 floats far above the bars on small datasets.
    label_offset = severity_counts.max() * 0.01
    for i, v in enumerate(severity_counts.values):
        pct = v / len(crash_train) * 100
        axes[0].text(i, v + label_offset, f'{v:,}\n({pct:.1f}%)', ha='center', va='bottom', fontweight='bold')
    axes[0].margins(y=0.15)  # headroom for the labels

    # Temporal trend ('year' is derived from Start_Time in the overview cell)
    if 'year' in crash_train.columns:
        yearly_severity = crash_train.groupby('year')['high_severity'].mean() * 100
        yearly_severity.plot(kind='line', marker='o', ax=axes[1])
        axes[1].set_title('High Severity Rate by Year', fontsize=14, fontweight='bold')
        axes[1].set_xlabel('Year')
        axes[1].set_ylabel('High Severity Rate (%)')
        axes[1].grid(True, alpha=0.3)
    else:
        axes[1].set_visible(False)  # don't render an empty panel

    plt.tight_layout()
    plt.show()

In [None]:
# Weather conditions vs severity: high-severity rate and crash volume for
# the ten most frequent weather conditions.
if crash_train is not None and all(
    col in crash_train.columns for col in ('Weather_Condition', 'high_severity')
):
    top_weather = crash_train['Weather_Condition'].value_counts().head(10).index

    # Per-condition severity rate ('mean') and crash count, highest rate first.
    weather_severity = (
        crash_train
        .loc[crash_train['Weather_Condition'].isin(top_weather)]
        .groupby('Weather_Condition')['high_severity']
        .agg(['mean', 'count'])
        .sort_values('mean', ascending=False)
    )

    fig, axes = plt.subplots(1, 2, figsize=(16, 5))

    # (series to plot, panel title, x-axis label) for left and right panels
    panels = [
        (weather_severity['mean'] * 100, 'High Severity Rate by Weather Condition', 'High Severity Rate (%)'),
        (weather_severity['count'].sort_values(), 'Crash Count by Weather Condition', 'Number of Crashes'),
    ]
    for ax, (series, title, xlabel) in zip(axes, panels):
        series.plot(kind='barh', ax=ax)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Weather Condition')

    plt.tight_layout()
    plt.show()

In [None]:
# Feature completeness by source: the share of non-missing values for each
# categorized feature, averaged per category (Target and Other excluded).
if crash_train is not None:
    completeness_data = [
        {
            'Category': category,
            'Feature': f,
            'Completeness': crash_train[f].notna().mean() * 100,
        }
        for category, features in feature_categories.items()
        if features and category not in ('Target', 'Other')
        for f in features
        if f in crash_train.columns
    ]

    if completeness_data:
        completeness_df = pd.DataFrame(completeness_data)

        # Mean completeness per category, ascending, so the best-covered
        # source is drawn at the top of the horizontal bar chart.
        category_avg = (
            completeness_df
            .groupby('Category')['Completeness']
            .mean()
            .sort_values()
        )

        plt.figure(figsize=(10, 6))
        category_avg.plot(kind='barh')
        plt.title('Average Feature Completeness by Data Source', fontsize=14, fontweight='bold')
        plt.xlabel('Completeness (%)')
        plt.ylabel('Data Source')
        plt.xlim(0, 100)
        plt.grid(axis='x', alpha=0.3)

        # Percentage label just to the right of each bar
        for y_pos, pct in enumerate(category_avg.values):
            plt.text(pct + 1, y_pos, f'{pct:.1f}%', va='center')

        plt.tight_layout()
        plt.show()

In [None]:
# HPMS features vs severity
if crash_train is not None and 'high_severity' in crash_train.columns:
    hpms_features = [c for c in crash_train.columns if c.startswith('hpms_')]

    if hpms_features:
        print("\n" + "="*70)
        print("🛣️  HPMS ROAD FEATURES vs SEVERITY")
        print("="*70)

        # Speed limit
        if 'hpms_speed_limit' in crash_train.columns:
            fig, axes = plt.subplots(1, 2, figsize=(14, 5))

            # Distribution
            crash_train.boxplot(column='hpms_speed_limit', by='high_severity', ax=axes[0])
            axes[0].set_title('Speed Limit by Severity Level', fontsize=12, fontweight='bold')
            axes[0].set_xlabel('High Severity')
            axes[0].set_ylabel('Speed Limit (mph)')
            plt.sca(axes[0])
            plt.xticks([1, 2], ['No (0)', 'Yes (1)'])

            # Severity rate by speed bins. The bins live in a standalone
            # Series instead of a temporary column, so crash_train is never
            # modified (even if plotting fails part-way through).
            speed_bins = [0, 30, 45, 60, 75, 100]
            speed_bin = pd.cut(crash_train['hpms_speed_limit'], bins=speed_bins).rename('speed_bin')
            # observed=False keeps empty bins so every range gets a bar, and
            # avoids relying on pandas' changing default for categorical keys.
            speed_severity = crash_train.groupby(speed_bin, observed=False)['high_severity'].agg(['mean', 'count'])
            severity_rate = speed_severity['mean'] * 100

            severity_rate.plot(kind='bar', ax=axes[1])
            axes[1].set_title('Severity Rate by Speed Limit Range', fontsize=12, fontweight='bold')
            axes[1].set_xlabel('Speed Limit Range (mph)')
            axes[1].set_ylabel('High Severity Rate (%)')
            axes[1].tick_params(axis='x', labelrotation=45)

            # Add counts above each bar. Empty bins have a NaN rate, so their
            # label is placed at the baseline instead.
            for i, (rate, count) in enumerate(zip(severity_rate, speed_severity['count'])):
                height = 0 if pd.isna(rate) else rate
                axes[1].text(i, height + 0.5, f"n={int(count):,}",
                             ha='center', va='bottom', fontsize=9)

            plt.tight_layout()
            plt.show()

## 9. Data Dictionary

Load and display the data dictionary if available.

In [None]:
# Load data dictionary (shipped alongside the crash-level splits)
data_dict_path = CRASH_LEVEL_DIR / 'DATA_DICTIONARY.md'

if data_dict_path.exists():
    print("\n" + "="*70)
    print("📖 DATA DICTIONARY")
    print("="*70)

    # Read as UTF-8 explicitly: the default text encoding is platform-
    # dependent (e.g. cp1252 on Windows) and may fail on non-ASCII markdown.
    print(data_dict_path.read_text(encoding='utf-8'))
else:
    print("\n⚠️  DATA_DICTIONARY.md not found")
    print(f"   Expected at: {data_dict_path}")

## 10. Next Steps for EDA

**Suggested analyses to explore:**

1. **Temporal Patterns**
   - Hour of day, day of week, month, season
   - Holiday effects
   - Temporal trends (2016-2023)

2. **Spatial Patterns**
   - City/county differences
   - Urban vs rural
   - Geographic clustering

3. **Weather Impact**
   - Weather conditions vs severity
   - Temperature, precipitation, visibility effects
   - Adverse weather combinations

4. **Road Characteristics**
   - Highway type (HPMS f_system)
   - Speed limit ranges
   - Lane counts
   - Pavement condition (IRI)
   - Traffic volume (AADT)

5. **Work Zone Effects**
   - Crashes in/near work zones
   - Work zone density effects

6. **Feature Correlations**
   - Correlation matrix for numeric features
   - Feature importance via random forest

7. **Missing Data Analysis**
   - Patterns in missingness
   - Impact on modeling

8. **Class Balance**
   - High severity rate across different segments
   - Potential need for resampling

In [None]:
# Summary of loaded datasets: shape of each DataFrame, or 'Not loaded' when
# its CSV was missing.
print("\n" + "="*70)
print("✅ DATASETS READY FOR ANALYSIS")
print("="*70)
print("\nAvailable DataFrames:")
loaded_frames = {
    'crash_train': crash_train,
    'crash_val': crash_val,
    'crash_test': crash_test,
    'segment_train': segment_train,
    'segment_val': segment_val,
    'segment_test': segment_test,
}
for frame_name, frame in loaded_frames.items():
    print(f"  - {frame_name}: {frame.shape if frame is not None else 'Not loaded'}")
print("\nHappy analyzing! 🎉")