In [19]:
def verify_data_files():
    """Verify all required data files are present."""
    required_files = [
        'Train.csv',
        'Test.csv',
        'composite_images.npz',
        'SampleSubmission.csv'
    ]
    
    missing_files = []
    for file in required_files:
        if not os.path.exists(os.path.join(BASE_PATH, file)):
            missing_files.append(file)
    
    if missing_files:
        raise FileNotFoundError(
            f"The following files are missing in {BASE_PATH}: {', '.join(missing_files)}"
        )

In [20]:
# Improved Flood Detection Model for South Africa

import os
import math
import pickle
from pathlib import Path
from typing import Any, Dict, Tuple, Optional
from collections.abc import Callable, Sequence
from functools import partial

# Data processing
import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter1d
from sklearn.model_selection import train_test_split

# Deep Learning
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
from flax.training import train_state
import optax

# Visualization
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Configuration
class Config:
    """Model and training configuration."""
    # Paths
    BASE_PATH = Path('c:/Users/CraigParker/OneDrive - Wits Health Consortium/PHR PC/Downloads/Joburg/Zindi_comp')
    
    # Data
    SEED = 42
    VALID_SIZE = 0.1
    BAND_NAMES = ('B2', 'B3', 'B4', 'B8', 'B11', 'slope')
    IMG_SHAPE = (128, 128, len(BAND_NAMES))
    
    # Model
    HIDDEN_DIM = 256
    NUM_HEADS = 4
    NUM_LAYERS = 6
    DROPOUT = 0.1
    
    # Training
    BATCH_SIZE = 32
    NUM_EPOCHS = 100
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 0.01
    
    # Time series preprocessing
    WINDOW_SIZES = [3, 7, 14, 30]

class DataProcessor:
    """Handles data loading and preprocessing."""
    def __init__(self, config: Config):
        self.config = config
        self.rng = np.random.default_rng(config.SEED)
    
    def load_raw_data(self) -> Tuple[pd.DataFrame, pd.DataFrame, Dict]:
        """Load raw data files."""
        # Load CSV files
        train_df = pd.read_csv(self.config.BASE_PATH / 'Train.csv')
        test_df = pd.read_csv(self.config.BASE_PATH / 'Test.csv')
        
        # Load image data
        images = np.load(self.config.BASE_PATH / 'composite_images.npz')
        
        return train_df, test_df, images
    
    def preprocess_time_series(self, data: np.ndarray) -> np.ndarray:
        """Process time series data with rolling features."""
        # Ensure 2D array
        if data.ndim == 1:
            data = data[np.newaxis, :]
            
        features = []
        # Original data
        features.append(data)
        
        # Rolling statistics
        for window in self.config.WINDOW_SIZES:
            # Rolling mean
            rolling_mean = np.array([
                np.convolve(row, np.ones(window)/window, mode='same')
                for row in data
            ])
            features.append(rolling_mean)
            
            # Rolling max
            rolling_max = np.array([
                np.maximum.accumulate(row)
                for row in rolling_mean
            ])
            features.append(rolling_max)
            
            # Rolling std
            rolling_std = np.array([
                [np.std(row[max(0, i-window):i+1]) for i in range(len(row))]
                for row in data
            ])
            features.append(rolling_std)
        
        # Cumulative sum
        features.append(np.cumsum(data, axis=1))
        
        # Smoothed signal
        features.append(gaussian_filter1d(data, sigma=2.0, axis=1))
        
        # Combine features
        combined = np.stack(features, axis=-1)
        
        # Normalize
        mean = np.mean(combined, axis=(0, 1), keepdims=True)
        std = np.std(combined, axis=(0, 1), keepdims=True) + 1e-8
        return ((combined - mean) / std).astype(np.float32)
    
    def preprocess_image(self, image: np.ndarray, augment: bool = False) -> np.ndarray:
        """Process satellite imagery."""
        # Split bands
        spectral = image[..., :-1].astype(np.float32)
        slope = image[..., -1:].astype(np.float32)
        
        # Normalize spectral bands
        spectral = (spectral - 1250) / 500
        
        # Convert slope to radians
        slope = (slope / np.iinfo(np.uint16).max * (np.pi / 2.0))
        
        # Combine processed bands
        processed = np.concatenate([spectral, slope], axis=-1)
        
        if augment and self.rng.random() > 0.5:
            # Random flip
            if self.rng.random() > 0.5:
                processed = np.flip(processed, axis=0)
            if self.rng.random() > 0.5:
                processed = np.flip(processed, axis=1)
            # Random 90-degree rotation
            k = self.rng.integers(4)
            processed = np.rot90(processed, k=k)
            
        return processed
    
    def prepare_datasets(self) -> Tuple[Dict, Dict, Dict]:
        """Prepare train, validation and test datasets."""
        # Load data
        train_df, test_df, images = self.load_raw_data()
        
        # Process event IDs
        train_df['event_id'] = train_df['event_id'].apply(lambda x: '_'.join(x.split('_')[:2]))
        test_df['event_id'] = test_df['event_id'].apply(lambda x: '_'.join(x.split('_')[:2]))
        
        # Ensure required columns are present
        required_columns = {'event_id', 'precipitation'}
        if not required_columns.issubset(train_df.columns):
            missing_cols = required_columns - set(train_df.columns)
            raise KeyError(f"Missing columns in train_df: {missing_cols}")
        if not required_columns.issubset(test_df.columns):
            missing_cols = required_columns - set(test_df.columns)
            raise KeyError(f"Missing columns in test_df: {missing_cols}")
        
        # Split train/validation
        train_events, valid_events = train_test_split(
            train_df['event_id'].unique(),
            test_size=self.config.VALID_SIZE,
            random_state=self.config.SEED
        )
        
        # Prepare datasets
        def prepare_set(events, df):
            mask = df['event_id'].isin(events)
            subset = df[mask]
            
            # Time series data
            ts = subset.pivot_table(
                index='event_id',
                columns=subset.groupby('event_id').cumcount(),
                values='precipitation'
            ).to_numpy()
            
            # Images
            imgs = np.stack([images[eid] for eid in events])
            
            # Labels (if available)
            labels = None
            if 'label' in subset.columns and 'event_t' in subset.columns:
                labels = subset.pivot(
                    index='event_id',
                    columns='event_t',
                    values='label'
                ).to_numpy()
            
            return {
                'timeseries': self.preprocess_time_series(ts),
                'images': imgs,
                'labels': labels
            }
        
        train_data = prepare_set(train_events, train_df)
        valid_data = prepare_set(valid_events, train_df)
        test_data = prepare_set(test_df['event_id'].unique(), test_df)
        
        return train_data, valid_data, test_data

def create_model(config: Config):
    """Create the flood detection model."""
    # Model architecture code will go here
    pass

def main():
    """Main execution function."""
    config = Config()
    
    # Initialize data processor
    processor = DataProcessor(config)
    
    print("Preparing datasets...")
    train_data, valid_data, test_data = processor.prepare_datasets()
    
    print("Dataset shapes:")
    print(f"Train: {train_data['timeseries'].shape}, {train_data['images'].shape}")
    print(f"Valid: {valid_data['timeseries'].shape}, {valid_data['images'].shape}")
    print(f"Test: {test_data['timeseries'].shape}, {test_data['images'].shape}")
    
    # Model training code will go here
    
if __name__ == "__main__":
    main()

Preparing datasets...
Dataset shapes:
Train: (606, 730, 15), (606, 128, 128, 6)
Valid: (68, 730, 15), (68, 128, 128, 6)
Test: (224, 730, 15), (224, 128, 128, 6)


In [21]:
import os
import math
import pickle
from pathlib import Path
from typing import Any, Dict, Tuple
import json

# Data processing
import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter1d
from sklearn.model_selection import train_test_split

# Deep Learning
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
from flax.training import train_state
import optax

# Visualization
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Configuration
class Config:
    def __init__(self):
        # Paths
        self.base_path = Path(r'C:\Users\CraigParker\OneDrive - Wits Health Consortium\PHR PC\Downloads\Joburg\Zindi_comp')
        
        # Data
        self.seed = 42
        self.valid_size = 0.1
        self.band_names = ('B2', 'B3', 'B4', 'B8', 'B11', 'slope')
        
        # Model
        self.patch_size = 16
        self.hidden_dim = 512
        self.num_layers = 8
        self.num_heads = 8
        self.dropout = 0.1
        
        # Training
        self.batch_size = 32
        self.num_epochs = 100
        self.learning_rate = 1e-4
        self.weight_decay = 0.01
        
        # Data processing
        self.required_files = ['Train.csv', 'Test.csv', 'composite_images.npz', 'SampleSubmission.csv']
        self.img_shape = (128, 128, len(self.band_names))

class DataProcessor:
    def __init__(self, config: Config):
        self.config = config
        self.rng = np.random.default_rng(config.seed)
    
    def load_data(self) -> Tuple[pd.DataFrame, pd.DataFrame, Dict]:
        """Load and preprocess all data."""
        # Load CSV files
        data = pd.read_csv(self.config.base_path / 'Train.csv')
        data_test = pd.read_csv(self.config.base_path / 'Test.csv')
        
        # Process event IDs and create time steps
        for df in [data, data_test]:
            df['event_id'] = df['event_id'].apply(lambda x: '_'.join(x.split('_')[:2]))
            df['event_t'] = df.groupby('event_id').cumcount()
        
        # Load images
        images = np.load(self.config.base_path / 'composite_images.npz')
        
        return data, data_test, images
    
    def preprocess_timeseries(self, data: np.ndarray) -> np.ndarray:
        """Process time series data."""
        features = [data]
        
        # Add rolling statistics
        windows = [3, 7, 14, 30]
        for window in windows:
            # Mean
            rolling_mean = np.array([
                np.convolve(row, np.ones(window)/window, mode='same')
                for row in data
            ])
            features.append(rolling_mean)
            
            # Max
            rolling_max = np.array([
                np.maximum.accumulate(row) for row in rolling_mean
            ])
            features.append(rolling_max)
        
        # Add cumulative and smoothed
        features.append(np.cumsum(data, axis=1))
        features.append(gaussian_filter1d(data, sigma=2.0, axis=1))
        
        # Combine and normalize
        combined = np.stack(features, axis=-1)
        mean = np.mean(combined, axis=(0, 1), keepdims=True)
        std = np.std(combined, axis=(0, 1), keepdims=True) + 1e-8
        return ((combined - mean) / std).astype(np.float32)
    
    def preprocess_image(self, image: np.ndarray, augment: bool = False) -> np.ndarray:
        """Process satellite imagery."""
        # Split bands
        spectral = image[..., :-1].astype(np.float32)
        slope = image[..., -1:].astype(np.float32)
        
        # Normalize
        spectral = (spectral - 1250) / 500
        slope = slope / np.iinfo(np.uint16).max * (np.pi / 2.0)
        
        # Combine
        processed = np.concatenate([spectral, slope], axis=-1)
        
        if augment and self.rng.random() > 0.5:
            if self.rng.random() > 0.5:
                processed = np.flip(processed, axis=0)
            if self.rng.random() > 0.5:
                processed = np.flip(processed, axis=1)
            processed = np.rot90(processed, k=self.rng.integers(4))
        
        return processed
    
    def prepare_datasets(self) -> Tuple[Dict, Dict, Dict]:
        """Prepare complete datasets for training."""
        # Load data
        data, data_test, images = self.load_data()
        
        # Create train/validation split
        train_events, valid_events = train_test_split(
            data['event_id'].unique(),
            test_size=self.config.valid_size,
            random_state=self.config.seed
        )
        
        def prepare_set(events, df):
            subset = df[df['event_id'].isin(events)]
            
            # Time series
            ts = subset.pivot(
                index='event_id',
                columns='event_t',
                values='precipitation'
            ).fillna(0).to_numpy()
            
            # Images
            imgs = np.stack([images[eid] for eid in events])
            
            # Labels if available
            labels = None
            if 'label' in df.columns:
                labels = subset.pivot(
                    index='event_id',
                    columns='event_t',
                    values='label'
                ).fillna(0).to_numpy()
            
            return {
                'timeseries': self.preprocess_timeseries(ts),
                'images': imgs,
                'labels': labels
            }
        
        return (
            prepare_set(train_events, data),
            prepare_set(valid_events, data),
            prepare_set(data_test['event_id'].unique(), data_test)
        )

class FloodDetectionModel(nn.Module):
    """Combined model for flood detection."""
    config: Config
    
    @nn.compact
    def __call__(self, inputs, training: bool = True):
        timeseries, images = inputs
        B, T, F = timeseries.shape
        
        # Process time series
        x_ts = nn.Dense(self.config.hidden_dim)(timeseries)
        
        # Add positional encoding
        position = jnp.arange(T)[None, :, None]
        div_term = jnp.exp(
            jnp.arange(0, self.config.hidden_dim, 2) * 
            (-math.log(10000.0) / self.config.hidden_dim)
        )
        pos_enc = jnp.zeros((1, T, self.config.hidden_dim))
        pos_enc = pos_enc.at[:, :, 0::2].set(jnp.sin(position * div_term))
        pos_enc = pos_enc.at[:, :, 1::2].set(jnp.cos(position * div_term))
        x_ts = x_ts + pos_enc
        
        # Transformer layers for time series
        for _ in range(self.config.num_layers):
            y = nn.LayerNorm()(x_ts)
            y = nn.MultiHeadDotProductAttention(
                num_heads=self.config.num_heads
            )(y, y, deterministic=not training)
            x_ts = x_ts + y
        
        # Process images
        x_img = nn.Conv(
            features=self.config.hidden_dim,
            kernel_size=(self.config.patch_size, self.config.patch_size),
            strides=(self.config.patch_size, self.config.patch_size)
        )(images)
        
        num_patches = (self.config.img_shape[0] // self.config.patch_size) ** 2
        x_img = x_img.reshape(B, num_patches, self.config.hidden_dim)
        
        # Transformer layers for images
        for _ in range(self.config.num_layers):
            y = nn.LayerNorm()(x_img)
            y = nn.MultiHeadDotProductAttention(
                num_heads=self.config.num_heads
            )(y, y, deterministic=not training)
            x_img = x_img + y
        
        # Global average pooling for images
        x_img = jnp.mean(x_img, axis=1)
        
        # Combine features
        x_img = jnp.expand_dims(x_img, axis=1)
        x_img = jnp.tile(x_img, (1, T, 1))
        
        x = jnp.concatenate([x_ts, x_img], axis=-1)
        
        # Output projection
        x = nn.Dense(1)(x)
        return jnp.squeeze(x, axis=-1)

class Trainer:
    def __init__(self, config: Config, model: nn.Module, train_data: Dict, valid_data: Dict):
        self.config = config
        self.model = model
        self.train_data = train_data
        self.valid_data = valid_data
        
        # Initialize training state
        rng = jax.random.PRNGKey(config.seed)
        dummy_batch = (
            jnp.ones((1, 730, train_data['timeseries'].shape[-1])),
            jnp.ones((1,) + config.img_shape)
        )
        variables = model.init(rng, dummy_batch)
        
        # Create optimizer
        tx = optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.adamw(
                learning_rate=optax.cosine_decay_schedule(
                    init_value=config.learning_rate,
                    decay_steps=config.num_epochs,
                    alpha=0.1
                ),
                weight_decay=config.weight_decay
            )
        )
        
        self.state = train_state.TrainState.create(
            apply_fn=model.apply,
            params=variables['params'],
            tx=tx
        )
    
    def train(self):
        """Training loop with validation."""
        for epoch in range(self.config.num_epochs):
            # Training
            with tqdm(range(0, len(self.train_data['timeseries']), self.config.batch_size),
                     desc=f"Epoch {epoch+1}/{self.config.num_epochs}") as pbar:
                
                for i in pbar:
                    batch_idx = slice(i, i + self.config.batch_size)
                    batch = {
                        'timeseries': self.train_data['timeseries'][batch_idx],
                        'images': self.train_data['images'][batch_idx],
                        'labels': self.train_data['labels'][batch_idx]
                    }
                    
                    # Training step
                    self.state, metrics = self.train_step(batch)
                    pbar.set_postfix({'loss': f"{metrics['loss']:.4f}",
                                    'acc': f"{metrics['accuracy']:.4f}"})
            
            # Validation
            valid_metrics = self.evaluate()
            print(f"\nValidation - Loss: {valid_metrics['loss']:.4f}, "
                  f"Acc: {valid_metrics['accuracy']:.4f}")
    
    @partial(jax.jit, static_argnums=(0,))
    def train_step(self, batch):
        """Single training step."""
        def loss_fn(params):
            logits = self.state.apply_fn(
                {'params': params},
                (batch['timeseries'], batch['images'])
            )
            loss = optax.sigmoid_binary_cross_entropy(logits, batch['labels']).mean()
            return loss, logits
        
        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(self.state.params)
        state = self.state.apply_gradients(grads=grads)
        
        # Compute metrics
        metrics = {
            'loss': loss,
            'accuracy': jnp.mean((jax.nn.sigmoid(logits) > 0.5) == batch['labels'])
        }
        
        return state, metrics
    
    def evaluate(self):
        """Evaluate on validation set."""
        metrics_list = []
        
        for i in range(0, len(self.valid_data['timeseries']), self.config.batch_size):
            batch_idx = slice(i, i + self.config.batch_size)
            batch = {
                'timeseries': self.valid_data['timeseries'][batch_idx],
                'images': self.valid_data['images'][batch_idx],
                'labels': self.valid_data['labels'][batch_idx]
            }
            
            logits = self.state.apply_fn(
                {'params': self.state.params},
                (batch['timeseries'], batch['images'])
            )
            
            loss = optax.sigmoid_binary_cross_entropy(logits, batch['labels']).mean()
            accuracy = jnp.mean((jax.nn.sigmoid(logits) > 0.5) == batch['labels'])
            
            metrics_list.append({'loss': loss, 'accuracy': accuracy})
        
        # Average metrics
        return {
            k: float(np.mean([m[k] for m in metrics_list]))
            for k in metrics_list[0].keys()
        }

def main():
    """Main execution function."""
    try:
        # Initialize config and data processor
        config = Config()
        processor = DataProcessor(config)
        
        print("Preparing datasets...")
        train_data, valid_data, test_data = processor.prepare_datasets()
        
        print("\nDataset shapes:")
        print(f"Train: {train_data['timeseries'].shape}")
        print(f"Valid: {valid_data['timeseries'].shape}")
        print(f"Test: {test_data['timeseries'].shape}")
        
        # Initialize model and trainer
        print("\nInitializing model...")
        model = FloodDetectionModel(config=config)
        
        print("Setting up trainer...")
        trainer = Trainer(config, model, train_data, valid_data)
        
        # Train model
        print("\nStarting training...")
        trainer.train()
        
        # Save model
        print("\nSaving model...")
        with open(config.base_path / 'flood_detection_model.pkl', 'wb') as f:
            pickle.dump(trainer.state, f)
        
        print("Training complete!")
        


SyntaxError: incomplete input (2070608978.py, line 376)

In [None]:
class TrainingVisualizer:
    """Visualize training progress and results."""
    def __init__(self):
        self.metrics_history = {
            'train_loss': [], 'train_acc': [],
            'valid_loss': [], 'valid_acc': [],
            'valid_auroc': []
        }
        
    def update_metrics(self, train_metrics, valid_metrics):
        """Update metrics history."""
        self.metrics_history['train_loss'].append(np.mean([m['loss'] for m in train_metrics]))
        self.metrics_history['train_acc'].append(np.mean([m['accuracy'] for m in train_metrics]))
        self.metrics_history['valid_loss'].append(np.mean([m['focal_loss'] for m in valid_metrics]))
        self.metrics_history['valid_acc'].append(np.mean([m['accuracy'] for m in valid_metrics]))
        self.metrics_history['valid_auroc'].append(np.mean([m['auroc'] for m in valid_metrics]))
    
    def plot_metrics(self):
        """Plot training and validation metrics."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        
        # Plot losses
        ax1.plot(self.metrics_history['train_loss'], label='Train Loss')
        ax1.plot(self.metrics_history['valid_loss'], label='Valid Loss')
        ax1.set_title('Loss Over Time')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True)
        
        # Plot accuracies
        ax2.plot(self.metrics_history['train_acc'], label='Train Acc')
        ax2.plot(self.metrics_history['valid_acc'], label='Valid Acc')
        ax2.plot(self.metrics_history['valid_auroc'], label='Valid AUROC')
        ax2.set_title('Metrics Over Time')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Score')
        ax2.legend()
        ax2.grid(True)
        
        plt.tight_layout()
        plt.show()
    
    def plot_prediction_examples(self, dataset, model, num_examples=5):
        """Plot example predictions."""
        rng = np.random.default_rng(42)
        indices = rng.choice(len(dataset), num_examples)
        
        fig, axes = plt.subplots(num_examples, 2, figsize=(12, 4*num_examples))
        
        for i, idx in enumerate(indices):
            # Get sample
            (ts, img), label = dataset.get_batch([idx])
            
            # Make prediction
            logits = model.apply({'params': model.params}, (ts, img))
            pred = jax.nn.sigmoid(logits)
            
            # Plot time series
            axes[i, 0].plot(ts[0, :, 0], label='Precipitation')
            if label is not None:
                axes[i, 0].plot(label[0], label='True Flood', alpha=0.5)
            axes[i, 0].plot(pred[0], label='Predicted Prob', alpha=0.5)
            axes[i, 0].set_title(f'Time Series - Sample {i+1}')
            axes[i, 0].legend()
            
            # Plot satellite image (RGB composite)
            rgb_img = img[0, :, :, [2,1,0]]  # Use bands 4,3,2 for RGB
            rgb_img = (rgb_img - rgb_img.min()) / (rgb_img.max() - rgb_img.min())
            axes[i, 1].imshow(rgb_img)
            axes[i, 1].set_title(f'Satellite Image - Sample {i+1}')
        
        plt.tight_layout()
        plt.show()

# Define the base path
BASE_PATH = r'C:\Users\CraigParker\OneDrive - Wits Health Consortium\PHR PC\Downloads\Joburg\Zindi_comp'

def load_and_preprocess_data():
    """Load and preprocess the data."""
    try:
        # Load CSV files
        data = pd.read_csv(os.path.join(BASE_PATH, 'Train.csv'))
        data_test = pd.read_csv(os.path.join(BASE_PATH, 'Test.csv'))
        
        # Debug: Print initial data info
        print("\nInitial data shapes:")
        print(f"Training data: {data.shape}")
        print(f"Test data: {data_test.shape}")
        print("\nTraining columns:", data.columns.tolist())
        
        # Process event IDs
        for df in [data, data_test]:
            # Clean event IDs
            df['event_id'] = df['event_id'].apply(lambda x: '_'.join(x.split('_')[:2]))
            # Create event_t column (time steps within each event)
            df['event_t'] = df.groupby('event_id').cumcount()
        
        # Verify created columns
        print("\nAfter preprocessing:")
        print("Training data columns:", data.columns.tolist())
        print("Sample event_t values:", data['event_t'].head())
        print("\nUnique events:", len(data['event_id'].unique()))
        print("Max time steps:", data['event_t'].max())
        
        # Verify event_t creation
        if 'event_t' not in data.columns:
            raise KeyError("Failed to create 'event_t' column in training data")
        if 'event_t' not in data_test.columns:
            raise KeyError("Failed to create 'event_t' column in test data")
        
        # Verify data integrity
        for df, name in [(data, 'training'), (data_test, 'test')]:
            # Check for required columns
            required_cols = ['event_id', 'event_t', 'precipitation']
            missing_cols = set(required_cols) - set(df.columns)
            if missing_cols:
                raise KeyError(f"Missing columns in {name} data: {missing_cols}")
            
            # Check for null values
            null_counts = df[required_cols].isnull().sum()
            if null_counts.any():
                print(f"\nWarning: Found null values in {name} data:")
                print(null_counts[null_counts > 0])
        
        return data, data_test
        
    except Exception as e:
        print(f"\nError in data preprocessing:")
        print(f"Type: {type(e).__name__}")
        print(f"Message: {str(e)}")
        raise

def validate_dataset(data: pd.DataFrame, name: str = "dataset"):
    """Validate dataset integrity and format."""
    print(f"\nValidating {name}...")
    
    # Check basic properties
    print(f"Shape: {data.shape}")
    print(f"Columns: {data.columns.tolist()}")
    
    # Check event structure
    num_events = len(data['event_id'].unique())
    timesteps_per_event = data.groupby('event_id').size()
    print(f"\nNumber of unique events: {num_events}")
    print(f"Timesteps per event:")
    print(f"  Min: {timesteps_per_event.min()}")
    print(f"  Max: {timesteps_per_event.max()}")
    print(f"  Mean: {timesteps_per_event.mean():.2f}")
    
    # Check data types
    print("\nData types:")
    print(data.dtypes)
    
    # Check for missing values
    missing = data.isnull().sum()
    if missing.any():
        print("\nMissing values:")
        print(missing[missing > 0])
    
    # Check precipitation values
    if 'precipitation' in data.columns:
        precip = data['precipitation']
        print("\nPrecipitation statistics:")
        print(f"  Min: {precip.min():.2f}")
        print(f"  Max: {precip.max():.2f}")
        print(f"  Mean: {precip.mean():.2f}")
        print(f"  Std: {precip.std():.2f}")
    
    # Check label distribution if present
    if 'label' in data.columns:
        label_dist = data['label'].value_counts(normalize=True)
        print("\nLabel distribution:")
        print(label_dist)
    
    print("\nValidation complete.")
    return True

def main():
    """Main execution function."""
    print("Verifying data files...")
    verify_data_files()
    
    print("Loading data...")
    data, data_test = load_and_preprocess_data()
    
    # Load image data
    print("Loading image data...")
    images_path = os.path.join(BASE_PATH, 'composite_images.npz')
    images = np.load(images_path)
    
    # Create train/validation split
    print("Creating dataset splits...")
    rng = np.random.default_rng(seed=42)
    event_ids = data['event_id'].unique()
    validation_size = int(len(event_ids) * 0.1)
    valid_ids = rng.choice(event_ids, size=validation_size, replace=False)
    
    # Create datasets
    train_mask = ~data['event_id'].isin(valid_ids)
    valid_mask = data['event_id'].isin(valid_ids)
    
    # Initialize datasets
    train_dataset = FloodDataset(
        timeseries=data[train_mask].pivot(index='event_id', columns='event_t', values='precipitation').to_numpy(),
        images=np.stack([images[id] for id in data[train_mask]['event_id'].unique()]),
        labels=data[train_mask].pivot(index='event_id', columns='event_t', values='label').to_numpy(),
        is_training=True
    )
    
    valid_dataset = FloodDataset(
        timeseries=data[valid_mask].pivot(index='event_id', columns='event_t', values='precipitation').to_numpy(),
        images=np.stack([images[id] for id in data[valid_mask]['event_id'].unique()]),
        labels=data[valid_mask].pivot(index='event_id', columns='event_t', values='label').to_numpy(),
        is_training=False
    )
    
    # Initialize model
    print("Initializing model...")
    model_config = {
        'patch_size': 16,
        'hidden_dim': 512,
        'num_layers': 8,
        'num_heads': 8,
        'dropout': 0.1,
        'input_shape': (32, 128, 128, 6)  # Add input shape
    }
    
    model = SatelliteViT(**model_config)
    
    # Initialize trainer
    print("Setting up trainer...")
    optimizer_config = {
        'learning_rate': 1e-4,
        'weight_decay': 0.01
    }
    
    trainer = Trainer(
        model=model,
        optimizer_config=optimizer_config,
        train_dataset=train_dataset,
        valid_dataset=valid_dataset,
        num_epochs=100,
        batch_size=32,
        steps_per_eval=5
    )
    
    # Initialize visualizer
    visualizer = TrainingVisualizer()
    
    # Training loop
    print("Starting training...")
    try:
        trainer.train(visualizer)
    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
    
    # Plot final results
    print("Plotting results...")
    visualizer.plot_metrics()
    visualizer.plot_prediction_examples(valid_dataset, model)
    
    # Save model
    print("Saving model...")
    model_path = os.path.join(BASE_PATH, 'flood_detection_model.pkl')
    with open(model_path, 'wb') as f:
        pickle.dump(trainer.state, f)
    
    print("Training complete!")

if __name__ == "__main__":
    main()

Verifying data files...
Loading data...
Training data columns: Index(['event_id', 'precipitation', 'label'], dtype='object')
Test data columns: Index(['event_id', 'precipitation'], dtype='object')


KeyError: "'event_t' column not found in training data"

In [23]:
"""
Simplified Flood Detection Model for South Africa

This script:
1) Verifies that all required data files exist.
2) Loads and preprocesses the training and test data.
3) Builds a simplified Transformer-based model with Flax/JAX.
4) Trains the model on the training set and evaluates on a validation set.
5) Saves the final model state.

Dependencies (install if needed):
  pip install jax jaxlib flax optax
  pip install numpy pandas matplotlib tqdm
"""

import os
import math
import pickle
from pathlib import Path
from typing import Dict, Tuple
from functools import partial

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax


# -----------------------------------------------------------------------------
# 1. Configuration
# -----------------------------------------------------------------------------
class Config:
    """Holds paths and hyperparameters for the project."""
    def __init__(self):
        # Paths
        self.BASE_PATH = Path(
            r"C:\Users\CraigParker\OneDrive - Wits Health Consortium\PHR PC\Downloads\Joburg\Zindi_comp"
        )
        self.REQUIRED_FILES = ["Train.csv", "Test.csv", "composite_images.npz", "SampleSubmission.csv"]

        # Random seed
        self.SEED = 42

        # Dataset
        self.VALID_SPLIT = 0.1  # 10% validation split
        self.BATCH_SIZE = 32

        # Model
        self.TIME_SERIES_HIDDEN_DIM = 128
        self.IMG_HIDDEN_DIM = 128
        self.NUM_LAYERS = 4
        self.NUM_HEADS = 4
        self.DROPOUT = 0.1

        # Training
        self.NUM_EPOCHS = 5
        self.LEARNING_RATE = 1e-4
        self.WEIGHT_DECAY = 0.01


# -----------------------------------------------------------------------------
# 2. Verify Required Data Files
# -----------------------------------------------------------------------------
def verify_data_files(config: Config) -> None:
    """Check if all required data files exist in the specified path."""
    missing_files = []
    for file_name in config.REQUIRED_FILES:
        file_path = config.BASE_PATH / file_name
        if not file_path.exists():
            missing_files.append(file_name)

    if missing_files:
        raise FileNotFoundError(
            f"Missing files in {config.BASE_PATH}: {', '.join(missing_files)}"
        )


# -----------------------------------------------------------------------------
# 3. Data Loading and Preprocessing
# -----------------------------------------------------------------------------
class DataProcessor:
    """Loads and preprocesses CSV files and image data."""

    def __init__(self, config: Config):
        self.config = config
        self.rng = np.random.default_rng(config.SEED)

    def load_raw_data(self) -> Tuple[pd.DataFrame, pd.DataFrame, Dict[str, np.ndarray]]:
        """Loads the CSV files and the composite images dictionary."""
        # Load train/test CSV
        train_df = pd.read_csv(self.config.BASE_PATH / "Train.csv")
        test_df = pd.read_csv(self.config.BASE_PATH / "Test.csv")

        # Clean event_id (remove trailing "_XXX")
        for df in [train_df, test_df]:
            df["event_id"] = df["event_id"].apply(lambda x: "_".join(x.split("_")[:2]))
            # Create a time-step column for each event
            df["event_t"] = df.groupby("event_id").cumcount()

        # Load images (npz file = dictionary of event_id -> 3D array)
        images_dict = dict(np.load(self.config.BASE_PATH / "composite_images.npz", allow_pickle=True))

        return train_df, test_df, images_dict

    def prepare_datasets(self) -> Tuple[Dict, Dict, Dict]:
        """Split the training data into train/val sets, and set up test data."""
        train_df, test_df, images_dict = self.load_raw_data()

        # Split event IDs for train/validation
        train_events, valid_events = train_test_split(
            train_df["event_id"].unique(),
            test_size=self.config.VALID_SPLIT,
            random_state=self.config.SEED
        )

        # A helper to pivot precipitation into timeseries shape
        def pivot_timeseries(df: pd.DataFrame, events: np.ndarray) -> np.ndarray:
            subset = df[df["event_id"].isin(events)]
            ts_matrix = subset.pivot(
                index="event_id", columns="event_t", values="precipitation"
            ).fillna(0).to_numpy()
            return ts_matrix

        # A helper to get stacked images for each event
        def gather_images(events: np.ndarray) -> np.ndarray:
            return np.stack([images_dict[e] for e in events])

        # A helper to pivot labels if they exist
        def pivot_labels(df: pd.DataFrame, events: np.ndarray) -> np.ndarray:
            if "label" not in df.columns:
                return None
            subset = df[df["event_id"].isin(events)]
            label_matrix = subset.pivot(
                index="event_id", columns="event_t", values="label"
            ).fillna(0).to_numpy()
            return label_matrix

        # Build dictionary-based data splits
        train_data = {
            "timeseries": pivot_timeseries(train_df, train_events),
            "images": gather_images(train_events),
            "labels": pivot_labels(train_df, train_events),
        }
        valid_data = {
            "timeseries": pivot_timeseries(train_df, valid_events),
            "images": gather_images(valid_events),
            "labels": pivot_labels(train_df, valid_events),
        }
        test_events = test_df["event_id"].unique()
        test_data = {
            "timeseries": pivot_timeseries(test_df, test_events),
            "images": gather_images(test_events),
            "labels": pivot_labels(test_df, test_events),  # Usually None
        }

        return train_data, valid_data, test_data


# -----------------------------------------------------------------------------
# 4. Model Definition
# -----------------------------------------------------------------------------
class SimpleFloodModel(nn.Module):
    """
    A simplified model that:
      - Encodes time-series data with a small Transformer.
      - Encodes images by flattening them and passing through a small Transformer.
      - Merges both features and outputs a per-time-step prediction.
    """

    config: Config

    @nn.compact
    def __call__(self, timeseries: jnp.ndarray, images: jnp.ndarray, train: bool = True):
        """
        Args:
          timeseries: (batch_size, T, ?) numeric data (precip, etc.)
          images:     (batch_size, H, W, C) satellite images
        Returns:
          (batch_size, T) float predictions
        """
        # 1) Time-series encoder
        # Project the timeseries data
        x_ts = nn.Dense(self.config.TIME_SERIES_HIDDEN_DIM)(timeseries.reshape(timeseries.shape[0], -1))

        # Apply a few transformer layers on time-series
        for _ in range(self.config.NUM_LAYERS):
            # Normalization
            y = nn.LayerNorm()(x_ts)
            # Self-attention
            y = nn.MultiHeadDotProductAttention(num_heads=self.config.NUM_HEADS)(
                y, y, deterministic=not train
            )
            x_ts = x_ts + y  # residual

        # 2) Image encoder (extremely simplified)
        # Flatten images: shape (B, H*W*C)
        b, h, w, c = images.shape
        x_img = images.reshape(b, h * w * c)
        x_img = nn.Dense(self.config.IMG_HIDDEN_DIM)(x_img)

        # Just do one transformer layer on the flattened image
        x_img = jnp.expand_dims(x_img, axis=1)  # shape (B, 1, hidden_dim)
        for _ in range(1):  # let's do 1 layer only
            y = nn.LayerNorm()(x_img)
            y = nn.MultiHeadDotProductAttention(num_heads=1)(y, y, deterministic=not train)
            x_img = x_img + y

        # Expand so we can concatenate with the time-series dimension
        # We'll replicate the single image feature across T steps
        T = timeseries.shape[1]
        x_img_tiled = jnp.tile(x_img, (1, T, 1))

        # 3) Combine time-series + image encodings
        combined = jnp.concatenate([x_ts, x_img_tiled], axis=-1)

        # Project down to (batch_size, T, 1)
        logits = nn.Dense(1)(combined)
        return jnp.squeeze(logits, axis=-1)  # shape (B, T)


# -----------------------------------------------------------------------------
# 5. Trainer Definition
# -----------------------------------------------------------------------------
class Trainer:
    def __init__(self, config: Config, model: nn.Module, train_data: Dict, valid_data: Dict):
        self.config = config
        self.model = model
        self.train_data = train_data
        self.valid_data = valid_data

        # Create a training state (params, optimizer)
        rng = jax.random.PRNGKey(config.SEED)

        # Dummy batch to initialize model
        dummy_ts = jnp.ones((1, 10, 1), dtype=jnp.float32)
        dummy_img = jnp.ones((1, 8, 8, 6), dtype=jnp.float32)  # Dummy shape
        variables = model.init(rng, dummy_ts, dummy_img)

        # Create optimizer schedule (e.g., simple constant or AdamW)
        tx = optax.adamw(learning_rate=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)

        self.state = train_state.TrainState.create(
            apply_fn=model.apply, params=variables["params"], tx=tx
        )

    def training_loop(self):
        """Run a simple training loop."""
        # Convert data to jax arrays once
        train_ts = jnp.array(self.train_data["timeseries"], dtype=jnp.float32)
        train_img = jnp.array(self.train_data["images"], dtype=jnp.float32)
        train_label = (
            jnp.array(self.train_data["labels"], dtype=jnp.float32)
            if self.train_data["labels"] is not None
            else None
        )

        num_samples = train_ts.shape[0]
        num_steps = (num_samples // self.config.BATCH_SIZE) or 1

        for epoch in range(self.config.NUM_EPOCHS):
            # Shuffle data
            perm = self._rng_permutation(num_samples)
            train_ts_shuffled = train_ts[perm]
            train_img_shuffled = train_img[perm]
            train_label_shuffled = train_label[perm] if train_label is not None else None

            epoch_losses = []

            with tqdm(range(num_steps), desc=f"Epoch {epoch+1}/{self.config.NUM_EPOCHS}") as pbar:
                for step_idx in pbar:
                    # Mini-batch slice
                    start = step_idx * self.config.BATCH_SIZE
                    end = start + self.config.BATCH_SIZE

                    batch_ts = train_ts_shuffled[start:end]
                    batch_img = train_img_shuffled[start:end]
                    batch_label = train_label_shuffled[start:end] if train_label_shuffled is not None else None

                    # Training step
                    self.state, loss_val = self._train_step(self.state, batch_ts, batch_img, batch_label)
                    epoch_losses.append(float(loss_val))
                    pbar.set_postfix({"loss": f"{np.mean(epoch_losses):.4f}"})

            # After each epoch, do a quick validation if labels exist
            if self.valid_data["labels"] is not None:
                valid_metrics = self.evaluate(self.valid_data)
                print(f"Validation after epoch {epoch+1}: loss={valid_metrics['loss']:.4f}")

        print("Training completed.")

    @partial(jax.jit, static_argnums=(0,))
    def _train_step(self, state, timeseries, images, labels):
        """Single training step."""
        def loss_fn(params):
            logits = self.model.apply({"params": params}, timeseries, images, train=True)
            if labels is None:
                # If no labels, no supervised loss
                return 0.0, logits
            # Simple binary cross-entropy
            loss = optax.sigmoid_binary_cross_entropy(logits, labels).mean()
            return loss, logits

        (loss, _), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        new_state = state.apply_gradients(grads=grads)
        return new_state, loss

    def evaluate(self, data_split: Dict):
        """Compute average loss on a given dataset split."""
        timeseries = jnp.array(data_split["timeseries"], dtype=jnp.float32)
        images = jnp.array(data_split["images"], dtype=jnp.float32)
        labels = data_split["labels"]
        if labels is None:
            return {"loss": np.nan}
        labels = jnp.array(labels, dtype=jnp.float32)

        logits = self.model.apply({"params": self.state.params}, timeseries, images, train=False)
        loss = optax.sigmoid_binary_cross_entropy(logits, labels).mean()
        return {"loss": float(loss)}

    def _rng_permutation(self, n):
        """Return a random permutation of indices [0..n-1], for data shuffling."""
        # We can do CPU-based permutation in NumPy, then convert
        indices = np.arange(n)
        self.config.SEED += 1  # increment seed each time for variety
        np.random.default_rng(self.config.SEED).shuffle(indices)
        return jnp.array(indices)


# -----------------------------------------------------------------------------
# 6. Main Script
# -----------------------------------------------------------------------------
def main():
    config = Config()

    # Verify required files
    verify_data_files(config)

    # Load and prepare datasets
    processor = DataProcessor(config)
    train_data, valid_data, test_data = processor.prepare_datasets()

    # Initialize model and trainer
    model = SimpleFloodModel(config=config)
    trainer = Trainer(config, model, train_data, valid_data)

    # Train model
    trainer.training_loop()

    # Evaluate on validation set
    if valid_data["labels"] is not None:
        val_metrics = trainer.evaluate(valid_data)
        print(f"Final validation loss: {val_metrics['loss']:.4f}")

    # Save model parameters
    with open(config.BASE_PATH / "flood_detection_model.pkl", "wb") as f:
        pickle.dump(trainer.state, f)
    print("Model saved successfully.")

    # We won't implement predictions on test_data here,
    # but you could follow the same approach used in `evaluate()`.


if __name__ == "__main__":
    main()


TypeError: Cannot concatenate arrays with different numbers of dimensions: got (1, 128), (1, 10, 128).