## 3. Optimized Transformer Model

### 3.1 Model Architecture

In [None]:
class OptimizedFireTransformer(nn.Module):
    """Transformer encoder for multi-area fire detection.

    Each time step is projected to ``d_model`` and summed with a learned
    per-area embedding and a learned positional embedding. The sequence is
    encoded by a stack of ``nn.TransformerEncoderLayer`` blocks and mean-pooled
    over time. The pooled vector feeds two heads: a fire-class classifier and
    a risk regressor.

    Args:
        input_dim: Number of features per time step.
        seq_len: Maximum sequence length; sizes the learned positional table.
        d_model: Hidden width of the encoder.
        num_heads: Attention heads per encoder layer (must divide ``d_model``).
        num_layers: Number of stacked encoder layers.
        num_classes: Number of outputs of the fire classifier head.
        num_areas: Number of area ids accepted by the area embedding.
        dropout: Dropout probability for the encoder and classifier head.
    """

    def __init__(self, input_dim=6, seq_len=60, d_model=128, num_heads=4,
                 num_layers=3, num_classes=3, num_areas=5, dropout=0.1):
        super().__init__()

        # Per-timestep projection of raw features into the model width
        self.input_proj = nn.Linear(input_dim, d_model)
        # One learned vector per area id; broadcast over time in forward()
        self.area_embedding = nn.Embedding(num_areas, d_model)
        # Learned positional table with one row per position up to seq_len.
        # NOTE(review): initialized with unit-variance noise (torch.randn);
        # confirm this scale is intended relative to the projected inputs.
        self.pos_encoding = nn.Parameter(torch.randn(seq_len, d_model))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=num_heads, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, activation='gelu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        # Classification head: one logit per fire class
        self.fire_classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Dropout(dropout),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, num_classes)
        )

        # Regression head; the final sigmoid bounds the raw output to [0, 1]
        self.risk_predictor = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, x, area_types):
        """Encode a batch of sequences and return both head outputs.

        Args:
            x: Float tensor of shape (batch, seq_len, input_dim). ``seq_len``
                may be shorter than, but not longer than, the length given
                to the constructor.
            area_types: Integer tensor of shape (batch,) with area ids in
                ``[0, num_areas)``.

        Returns:
            dict with ``'fire_logits'`` of shape (batch, num_classes) and
            ``'risk_score'`` of shape (batch, 1), scaled to [0, 100].

        Raises:
            ValueError: if ``x`` is not 3-D, or its sequence is longer than the
                learned positional table.
        """
        # Tuple unpacking also enforces a 3-D input
        _, seq_len, _ = x.shape
        max_len = self.pos_encoding.shape[0]
        if seq_len > max_len:
            # Without this check the addition below fails with an opaque
            # broadcasting error.
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional table "
                f"size {max_len} set at construction"
            )

        x = self.input_proj(x)
        # (batch, d_model) -> (batch, seq_len, d_model) so each step gets it
        area_emb = self.area_embedding(area_types).unsqueeze(1).expand(-1, seq_len, -1)
        x = x + area_emb + self.pos_encoding[:seq_len].unsqueeze(0)

        x = self.transformer(x)
        # Mean-pool over the time axis: one vector per sequence
        x = x.mean(dim=1)

        return {
            'fire_logits': self.fire_classifier(x),
            'risk_score': self.risk_predictor(x) * 100.0,
        }

### 3.2 Model Parameter Visualization

In [None]:
def _print_parameter_summary(total_params, trainable_params, original_params):
    """Print parameter totals and, if a baseline is given, the change against it."""
    print("📊 Model Parameter Summary:")
    print(f"   Total parameters: {total_params:,}")
    print(f"   Trainable parameters: {trainable_params:,}")
    print(f"   Non-trainable parameters: {total_params - trainable_params:,}")

    if not original_params:
        return
    reduction = 1 - (total_params / original_params)
    if reduction >= 0:
        print(f"   Parameter reduction: {reduction:.1%} from original model")
    else:
        # Bigger than the baseline: report an increase, not a negative reduction
        print(f"   Parameter increase: {-reduction:.1%} over original model")


def visualize_model_parameters(model, original_params=256 * 6 * 4 * 256):
    """Chart how a model's parameters are spread across its top-level submodules.

    Parameters are grouped by the first dotted component of their name (for
    example ``input_proj`` or ``transformer``). The counts are drawn as a pie
    chart and a horizontal bar chart, the figure is saved when
    ``VISUALIZATION_CONFIG['save_figures']`` is set, and it is shown. A
    summary of total, trainable and non-trainable counts is then printed.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
        original_params: Baseline parameter count the summary compares against.
            The default is the notebook's rough estimate for the original model.
            Pass a measured count for a meaningful comparison, or ``None`` to
            skip the comparison.

    Returns:
        dict mapping each top-level submodule name to its parameter count.
    """
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Group counts by top-level attribute; dict keeps registration order
    layer_params = {}
    for name, param in model.named_parameters():
        layer_name = name.split('.', 1)[0]
        layer_params[layer_name] = layer_params.get(layer_name, 0) + param.numel()

    if not layer_params:
        # Nothing to chart (parameter-free module); still report the counts
        _print_parameter_summary(total_params, trainable_params, original_params)
        return layer_params

    labels = list(layer_params)
    sizes = list(layer_params.values())

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Share of parameters per submodule
    ax1.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=90)
    ax1.axis('equal')
    ax1.set_title('Parameter Distribution by Layer')

    # Absolute parameter counts per submodule
    y_pos = np.arange(len(labels))
    ax2.barh(y_pos, sizes)
    ax2.set_yticks(y_pos)
    ax2.set_yticklabels(labels)
    ax2.set_xlabel('Number of Parameters')
    ax2.set_title('Parameter Count by Layer')

    # Offset value labels by 1% of the largest bar so they clear the bar end
    # (a fixed offset is invisible when counts are in the millions).
    label_offset = max(sizes) * 0.01
    for i, v in enumerate(sizes):
        ax2.text(v + label_offset, i, f"{v:,}", va='center')

    fig.tight_layout()

    if VISUALIZATION_CONFIG['save_figures']:
        fig.savefig(f"{VISUALIZATION_CONFIG['figure_dir']}/model_parameters.png", dpi=300)

    plt.show()

    _print_parameter_summary(total_params, trainable_params, original_params)
    return layer_params

## 4. Training Dashboard and Visualizations

### 4.1 Real-time Training Dashboard

In [None]:
def create_training_dashboard():
    """Build the 2x2 training-progress figure plus an empty metrics history.

    Returns:
        tuple ``(fig, axes, dashboard_data)``:
        - ``fig``: the figure;
        - ``axes``: its 2x2 axes array (loss, accuracy, learning rate, memory);
        - ``dashboard_data``: a dict of empty lists keyed by metric name.
    """
    history_keys = (
        'epochs',
        'train_loss',
        'val_accuracy',
        'learning_rate',
        'time_elapsed',
        'memory_usage',
    )
    dashboard_data = {key: [] for key in history_keys}

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Fire Detection AI - Training Progress Dashboard', fontsize=16)

    # (row, col, line format, series label, panel title, y-axis label)
    panel_specs = [
        (0, 0, 'b-', 'Training Loss', 'Training Loss', 'Loss'),
        (0, 1, 'g-', 'Validation Accuracy', 'Validation Accuracy', 'Accuracy'),
        (1, 0, 'r-', 'Learning Rate', 'Learning Rate', 'Learning Rate'),
        (1, 1, 'm-', 'GPU Memory Usage (GB)', 'GPU Memory Usage', 'Memory (GB)'),
    ]
    for row, col, fmt, series_label, title, y_label in panel_specs:
        panel = axes[row, col]
        # Empty line carrying the panel's series style and label
        panel.plot([], [], fmt, label=series_label)
        panel.set_title(title)
        panel.set_xlabel('Epoch')
        panel.set_ylabel(y_label)
        panel.grid(True)

    # Reserve the top 5% of the figure for the suptitle
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    return fig, axes, dashboard_data

def update_dashboard(fig, axes, dashboard_data, epoch, loss, accuracy, lr, time_elapsed, memory_usage):
    """Append one epoch of metrics to the dashboard and redraw it in place.

    Each panel's existing line is updated rather than adding a new line every
    call, so the number of artists stays constant across epochs. A single
    status text at the bottom of the figure is reused. All drawing and saving
    targets ``fig`` explicitly, not whatever figure pyplot considers current.

    Args:
        fig: Figure to draw into.
        axes: 2x2 axes array of ``fig``.
        dashboard_data: Dict of per-metric history lists, appended to in place.
        epoch: Current epoch number.
        loss: Training loss for this epoch.
        accuracy: Validation accuracy for this epoch.
        lr: Learning rate for this epoch.
        time_elapsed: Seconds elapsed, shown in the status text.
        memory_usage: Memory reading for this epoch (plotted in the GB panel).
    """
    new_values = {
        'epochs': epoch,
        'train_loss': loss,
        'val_accuracy': accuracy,
        'learning_rate': lr,
        'time_elapsed': time_elapsed,
        'memory_usage': memory_usage,
    }
    for key, value in new_values.items():
        dashboard_data[key].append(value)

    epochs = dashboard_data['epochs']
    # (panel, history key, line format used if the panel has no line yet)
    panels = (
        (axes[0, 0], 'train_loss', 'b-'),
        (axes[0, 1], 'val_accuracy', 'g-'),
        (axes[1, 0], 'learning_rate', 'r-'),
        (axes[1, 1], 'memory_usage', 'm-'),
    )
    for ax, key, fmt in panels:
        if ax.lines:
            # Reuse the existing line instead of stacking a new Line2D per epoch
            ax.lines[0].set_data(epochs, dashboard_data[key])
        else:
            ax.plot(epochs, dashboard_data[key], fmt)
        ax.relim()
        ax.autoscale_view()

    # Reuse one tagged status text; adding a new one each epoch stacks them
    status = (f"Epoch: {epoch} | Loss: {loss:.4f} | Accuracy: {accuracy:.4f} "
              f"| Time: {time_elapsed:.1f}s")
    status_text = next((t for t in fig.texts if t.get_gid() == 'dashboard_status'), None)
    if status_text is None:
        status_text = fig.text(0.5, 0.01, status, ha="center", fontsize=12,
                               bbox={"facecolor": "orange", "alpha": 0.2, "pad": 5})
        status_text.set_gid('dashboard_status')
    else:
        status_text.set_text(status)

    # Redraw and replace the previous cell output with the fresh figure
    fig.canvas.draw()
    display.clear_output(wait=True)
    display.display(fig)

    if VISUALIZATION_CONFIG['save_figures']:
        fig.savefig(f"{VISUALIZATION_CONFIG['figure_dir']}/training_dashboard_epoch_{epoch}.png", dpi=300)

### 4.2 Memory Usage Monitoring

In [None]:
def monitor_memory_usage():
    """Start recording memory usage and return a sampling callback.

    When CUDA is available, PyTorch's allocated GPU memory is tracked.
    Otherwise the process's resident set size is read via ``psutil``, if that
    package is importable. An initial sample is recorded immediately at t=0.

    Returns:
        tuple ``(track_memory, memory_usage, timestamps)``:
        - ``track_memory()`` takes a sample and appends it (in GB, i.e.
          bytes / 1e9) to ``memory_usage``. It appends the seconds elapsed
          since this function was called to ``timestamps``, and returns
          ``(current_gb, peak_gb)``.
        - If neither CUDA nor ``psutil`` is available, ``track_memory()``
          returns ``(0, 0)`` and both lists stay empty.
    """
    import os
    import time

    memory_usage = []
    timestamps = []
    # Reference point for all timestamps. The original code read an undefined
    # `start_time` here, which raised NameError on the first sample.
    start_time = time.monotonic()

    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        memory_usage.append(torch.cuda.memory_allocated() / 1e9)
        timestamps.append(0)

        def track_memory():
            current_memory = torch.cuda.memory_allocated() / 1e9
            peak_memory = torch.cuda.max_memory_allocated() / 1e9
            memory_usage.append(current_memory)
            timestamps.append(time.monotonic() - start_time)
            return current_memory, peak_memory

        return track_memory, memory_usage, timestamps

    # CPU path: only the import itself is expected to fail
    try:
        import psutil
    except ImportError:
        def track_memory():
            return 0, 0

        return track_memory, memory_usage, timestamps

    process = psutil.Process(os.getpid())
    memory_usage.append(process.memory_info().rss / 1e9)
    timestamps.append(0)

    def track_memory():
        current_memory = process.memory_info().rss / 1e9
        memory_usage.append(current_memory)
        timestamps.append(time.monotonic() - start_time)
        # RSS has no built-in high-water mark; use the largest sample seen
        return current_memory, max(memory_usage)

    return track_memory, memory_usage, timestamps

def visualize_memory_usage(memory_usage, timestamps):
    """Plot recorded memory samples over time with a dashed peak-usage line.

    Saves the figure when ``VISUALIZATION_CONFIG['save_figures']`` is set,
    then shows it. If no samples were recorded, prints a notice and returns
    without plotting.

    Args:
        memory_usage: Sequence of memory readings in GB.
        timestamps: Sequence of elapsed-seconds values, same length as
            ``memory_usage``.
    """
    if len(memory_usage) == 0:
        # e.g. the fallback tracker without CUDA/psutil records nothing, and
        # max() below would raise ValueError on an empty sequence.
        print("⚠️ No memory samples recorded; nothing to plot.")
        return

    fig = plt.figure(figsize=(12, 6))
    plt.plot(timestamps, memory_usage, 'b-')
    plt.title('Memory Usage During Training')
    plt.xlabel('Time (seconds)')
    plt.ylabel('Memory Usage (GB)')
    plt.grid(True)

    # Mark peak memory with a dashed horizontal line and legend entry
    peak_memory = max(memory_usage)
    plt.axhline(y=peak_memory, color='r', linestyle='--', label=f'Peak: {peak_memory:.2f} GB')
    plt.legend()

    if VISUALIZATION_CONFIG['save_figures']:
        fig.savefig(f"{VISUALIZATION_CONFIG['figure_dir']}/memory_usage.png", dpi=300)

    plt.show()

### 4.3 Initialize Model and Visualization

In [None]:
# Select the compute device: the default CUDA GPU when available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"🚀 Device: {device}")
if torch.cuda.is_available():
    logger.info(f"   GPU: {torch.cuda.get_device_name(0)}")
    logger.info(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Size the classifier head and the area embedding from the largest id rather
# than the number of distinct ids. nn.Embedding requires area ids in
# [0, num_areas), and len(np.unique(...)) undercounts when ids are
# non-contiguous or an id is missing from the training split.
# Assumes labels and area ids are non-negative integer indices — confirm
# against the data preparation step.
num_classes = int(np.max(y_train)) + 1
num_areas = int(np.max(areas_train)) + 1

# Initialize model; input width and sequence length come from the training windows
model = OptimizedFireTransformer(
    input_dim=X_train.shape[2],
    seq_len=X_train.shape[1],
    d_model=TRANSFORMER_CONFIG['d_model'],
    num_heads=TRANSFORMER_CONFIG['num_heads'],
    num_layers=TRANSFORMER_CONFIG['num_layers'],
    num_classes=num_classes,
    num_areas=num_areas,
    dropout=TRANSFORMER_CONFIG['dropout']
).to(device)

# Visualize model parameters
visualize_model_parameters(model)

# Initialize training dashboard
dashboard_fig, dashboard_axes, dashboard_data = create_training_dashboard()

# Initialize memory tracking
track_memory, memory_usage, timestamps = monitor_memory_usage()