# In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.io
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility.
# Both global RNGs are seeded: PyTorch's (model code below) and NumPy's
# (np.random.normal is used when synthetic data is generated).
torch.manual_seed(42)
np.random.seed(42)

class BatteryLSTM(nn.Module):
    """LSTM regressor that pools the sequence with a learned attention weight.

    The input is processed batch-first, i.e. ``x`` is indexed as
    ``(batch, time step, feature)``. Every time step of the LSTM output gets a
    scalar score, the scores are soft-maxed over the time axis, and the
    weighted sum of the LSTM outputs is fed to a small fully connected head.

    Args:
        input_size: number of features per time step.
        hidden_size: LSTM hidden width; the head narrows it to ``hidden_size // 2``
            and ``hidden_size // 4`` before the output layer.
        num_layers: number of stacked LSTM layers.
        output_size: number of values predicted per sequence.
        dropout_rate: dropout probability used in the head and, when the LSTM
            is stacked, between LSTM layers.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout_rate=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # nn.LSTM applies dropout only *between* stacked layers. With a single
        # layer a non-zero value does nothing except emit a warning, so it is
        # disabled explicitly in that case.
        lstm_dropout = dropout_rate if num_layers > 1 else 0.0

        # LSTM layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True, dropout=lstm_dropout)

        # Attention mechanism: one scalar score per time step
        self.attention = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1)
        )

        # Fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 4, output_size)
        )

    def forward(self, x):
        """Return one prediction row per input sequence.

        Args:
            x: batch-first tensor of sequences.

        Returns:
            Tensor with ``output_size`` values for every sequence in the batch.
        """
        # Only the per-step outputs are needed; final hidden/cell states are unused.
        lstm_out, _ = self.lstm(x)

        # Normalise the scores across the time axis (dim=1) so they sum to one,
        # then collapse the sequence into a single context vector.
        attention_weights = torch.softmax(self.attention(lstm_out), dim=1)
        context_vector = torch.sum(attention_weights * lstm_out, dim=1)

        return self.fc_layers(context_vector)

class BatteryDataProcessor:
    """Turn battery ``.mat`` recordings into scaled sliding-window sequences.

    Attributes:
        sequence_length: number of consecutive cycles that form one model input.
        scaler: ``StandardScaler`` applied to the feature columns.
        capacity_scaler: ``StandardScaler`` applied to the capacity target; it is
            also what callers use to map predictions back to original units.
    """

    # Measurement fields looked up inside every discharge record.
    _MEASUREMENT_FIELDS = (
        'Voltage_measured', 'Current_measured',
        'Temperature_measured', 'Capacity',
        'Current_load', 'Voltage_load', 'Time',
    )

    def __init__(self, sequence_length=10):
        self.sequence_length = sequence_length
        self.scaler = StandardScaler()
        self.capacity_scaler = StandardScaler()

    def load_battery_data(self, file_path):
        """Load a MATLAB battery data file.

        Returns:
            The dict produced by ``scipy.io.loadmat``, or ``None`` when the file
            cannot be read (the error is printed, not raised).
        """
        try:
            return scipy.io.loadmat(file_path)
        except Exception as e:
            # Deliberately best-effort: one unreadable file must not abort the
            # scan of a whole directory.
            print(f"Error loading file {file_path}: {e}")
            return None

    @staticmethod
    def _unwrap_mat_value(value):
        """Peel single-element object wrappers off a value read by ``loadmat``.

        MATLAB structs and cells arrive as object arrays that hold further
        arrays. This returns the innermost ``ndarray`` so callers do not have
        to hard-code how many ``[0, 0]`` levels a particular file uses.
        """
        arr = np.asarray(value)
        while (arr.dtype.kind == 'O' and arr.size == 1
               and isinstance(arr.flat[0], np.ndarray)):
            arr = arr.flat[0]
        return arr

    def _find_cycle_struct(self, data):
        """Return the array of per-cycle records, or ``None`` if there is none.

        A top-level ``'cycle'`` entry is used when present. Otherwise the first
        struct that has a ``cycle`` field is used; the published NASA files are
        expected to nest the cycles under a key named after the battery
        (e.g. ``B0005``) -- verify against your copy of the data.
        """
        if 'cycle' in data:
            return self._unwrap_mat_value(data['cycle'])

        for key, value in data.items():
            if key.startswith('__'):
                continue
            names = getattr(getattr(value, 'dtype', None), 'names', None)
            if names and 'cycle' in names and value.size > 0:
                return self._unwrap_mat_value(value.flat[0]['cycle'])
        return None

    def explore_mat_structure(self, data, file_name):
        """Print type, shape, dtype and struct fields of every user key."""
        print(f"\n=== Exploring {file_name} ===")
        for key in data.keys():
            if key.startswith('__'):
                continue  # loadmat bookkeeping entries

            value = data[key]
            print(f"Key: {key}")
            print(f"  Type: {type(value)}")
            if hasattr(value, 'shape'):
                print(f"  Shape: {value.shape}")
            if hasattr(value, 'dtype'):
                print(f"  Dtype: {value.dtype}")

            # If it's a structured array, show the fields
            if hasattr(value, 'dtype') and value.dtype.names:
                print(f"  Fields: {value.dtype.names}")
                # Use .size/.flat instead of len()/[0, 0] so zero-size and
                # higher-rank struct arrays do not raise.
                as_array = np.asarray(value)
                if as_array.size > 0:
                    first_elem = as_array.flat[0]
                    if hasattr(first_elem, 'dtype') and first_elem.dtype.names:
                        print(f"  First element fields: {first_elem.dtype.names}")

    def extract_features_from_arc_data(self, data, battery_id):
        """Extract one feature row per discharge cycle from NASA ARC style data.

        For every cycle whose ``type`` is ``'discharge'`` the measurement series
        are summarised by mean/std/max/min over the *whole* series, and the
        scalar ``Capacity`` is stored as ``capacity``. Cycles without a capacity
        value are dropped.

        Args:
            data: dict returned by ``scipy.io.loadmat``.
            battery_id: identifier stored in the ``battery_id`` column.

        Returns:
            DataFrame with ``cycle`` (1-based position in the cycle list),
            ``battery_id``, ``capacity`` and the statistic columns; empty when
            no usable cycle is found.
        """
        cycles = []

        cycle_data = self._find_cycle_struct(data)
        if cycle_data is None:
            return pd.DataFrame(cycles)

        print(f"Processing {battery_id} - Cycles: {cycle_data.shape}")

        for cycle_num, cycle in enumerate(cycle_data.reshape(-1)):
            # Check if cycle has the expected structure
            names = getattr(getattr(cycle, 'dtype', None), 'names', None)
            if not names or 'data' not in names or 'type' not in names:
                continue

            # Only process discharge cycles
            cycle_type = self._unwrap_mat_value(cycle['type'])
            if cycle_type.size == 0 or str(cycle_type.flat[0]) != 'discharge':
                continue

            cycle_info = self._unwrap_mat_value(cycle['data'])
            if cycle_info.size == 0 or not cycle_info.dtype.names:
                continue
            record = cycle_info.flat[0]

            features = {'cycle': cycle_num + 1, 'battery_id': battery_id}
            for field in self._MEASUREMENT_FIELDS:
                if field not in record.dtype.names:
                    continue
                values = self._unwrap_mat_value(record[field])
                if values.size == 0 or not np.issubdtype(values.dtype, np.number):
                    continue

                column = field.lower()
                if field == 'Capacity':
                    features[column] = float(values.flat[0])
                else:
                    # Calculate statistics for time-series data
                    features[f'{column}_mean'] = float(np.mean(values))
                    features[f'{column}_std'] = float(np.std(values))
                    features[f'{column}_max'] = float(np.max(values))
                    features[f'{column}_min'] = float(np.min(values))

            # Only add if we have capacity data
            if 'capacity' in features:
                cycles.append(features)

        return pd.DataFrame(cycles)

    def extract_features_simple(self, data, battery_id):
        """Fallback extraction: read any key whose name contains ``capacity``.

        Every numeric element of such an array becomes one row with a 1-based
        ``cycle`` index. Empty or non-numeric arrays are skipped.
        """
        cycles = []

        # Try to find capacity data directly
        for key in data.keys():
            if key.startswith('__'):
                continue

            value = data[key]
            print(f"Checking key {key} with shape {value.shape if hasattr(value, 'shape') else 'N/A'}")

            # Look for capacity-like data
            if 'capacity' not in key.lower() or not hasattr(value, 'shape'):
                continue

            try:
                capacities = np.asarray(value).astype(float).ravel()
            except (TypeError, ValueError):
                continue  # e.g. a unit string or a nested struct

            for cycle_index, capacity in enumerate(capacities, start=1):
                cycles.append({
                    'cycle': cycle_index,
                    'battery_id': battery_id,
                    'capacity': float(capacity)
                })

        return pd.DataFrame(cycles)

    def load_all_battery_files(self, data_directory):
        """Load every ``*.mat`` file in a directory.

        Returns:
            dict mapping the file stem (used as battery id) to its feature
            DataFrame; files that yield no cycles are left out.
        """
        data_path = Path(data_directory)
        all_data = {}

        # Sorted so the battery order does not depend on the file system.
        mat_files = sorted(data_path.glob('*.mat'))
        print(f"Found {len(mat_files)} .mat files:")

        for mat_file in mat_files:
            print(f"  - {mat_file.name}")
            battery_id = mat_file.stem
            data = self.load_battery_data(mat_file)
            if data is None:
                continue

            # Explore the structure first
            self.explore_mat_structure(data, battery_id)

            # Try different extraction methods
            df = self.extract_features_from_arc_data(data, battery_id)
            if df.empty:
                df = self.extract_features_simple(data, battery_id)

            if not df.empty:
                all_data[battery_id] = df
                print(f"  Extracted {len(df)} cycles")
            else:
                print(f"  No cycles extracted from {battery_id}")

        return all_data

    def combine_battery_data(self, all_data):
        """Stack the per-battery frames into one DataFrame.

        The ``battery_id`` column of every row is set to the dict key. The
        input frames are not modified.
        """
        frames = [
            df.assign(battery_id=battery_id)
            for battery_id, df in all_data.items()
            if not df.empty
        ]
        combined_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

        print(f"Combined data shape: {combined_df.shape}")
        return combined_df

    def create_sequences(self, data, target_col='capacity'):
        """Create sliding-window sequences for LSTM training.

        Each window holds ``sequence_length`` consecutive cycles of one battery
        and its target is the scaled capacity of the cycle right after the
        window. Windows never span two batteries.

        Both scalers are fitted once on the rows of all batteries, so the
        result does not depend on which battery happens to come first.

        Args:
            data: frame with ``cycle``, ``battery_id`` and ``target_col``.
            target_col: name of the column to predict.

        Returns:
            ``(sequences, targets, sequence_info)`` where ``sequence_info`` has
            one dict per window with ``battery_id``, ``actual_cycle`` and the
            unscaled ``actual_capacity`` of the target cycle.
        """
        sequences = []
        targets = []
        sequence_info = []  # Store information about each sequence

        # Select feature columns (exclude cycle, battery_id and the target)
        excluded = {'cycle', 'battery_id', target_col}
        feature_cols = [col for col in data.columns
                        if col not in excluded
                        and pd.api.types.is_numeric_dtype(data[col])]
        if not feature_cols:
            # If no features, use cycle number as feature
            feature_cols = ['cycle']

        # Used when a column is entirely NaN for one battery.
        global_means = data[feature_cols].mean()

        prepared = []
        for battery_id in data['battery_id'].unique():
            battery_data = data[data['battery_id'] == battery_id].sort_values('cycle')

            # Handle NaN values: battery mean first, overall mean as fallback
            features = battery_data[feature_cols]
            features = features.fillna(features.mean()).fillna(global_means)

            target_values = battery_data[target_col].to_numpy(dtype=float)
            target_values = np.nan_to_num(target_values, nan=np.nanmean(target_values))

            prepared.append((battery_id, battery_data, features, target_values))

        if not prepared:
            return np.array(sequences), np.array(targets), sequence_info

        # Fit the scalers on every battery before transforming any of them
        self.scaler.fit(pd.concat([item[2] for item in prepared]))
        self.capacity_scaler.fit(
            np.concatenate([item[3] for item in prepared]).reshape(-1, 1)
        )

        for battery_id, battery_data, features, target_values in prepared:
            scaled_features = self.scaler.transform(features)
            scaled_targets = self.capacity_scaler.transform(
                target_values.reshape(-1, 1)
            ).flatten()

            # Plain lists avoid building a row Series for every window.
            cycle_numbers = battery_data['cycle'].tolist()
            raw_capacities = battery_data[target_col].tolist()

            # Create sequences for this battery
            for start in range(len(battery_data) - self.sequence_length):
                end = start + self.sequence_length
                sequences.append(scaled_features[start:end])
                targets.append(scaled_targets[end])

                # Store sequence information for later reference
                sequence_info.append({
                    'battery_id': battery_id,
                    'actual_cycle': cycle_numbers[end],
                    'actual_capacity': raw_capacities[end]
                })

        return np.array(sequences), np.array(targets), sequence_info

    def prepare_synthetic_data(self, num_cycles=500):
        """Generate synthetic battery data for demonstration.

        Capacity starts at 2.0, shrinks with a linear plus quadratic term in
        the cycle index, gets Gaussian noise and is floored at 0.5. Values are
        drawn from NumPy's global RNG, so seeding it makes the frame repeatable.
        """
        cycles = []

        for i in range(num_cycles):
            # Capacity degradation
            base_capacity = 2.0
            degradation = 0.001 * i + 0.0001 * i**2
            noise = np.random.normal(0, 0.01)
            capacity = base_capacity - degradation + noise

            # Other features; the draw order is part of the reproducible output
            avg_voltage = 3.7 - 0.0005 * i + np.random.normal(0, 0.02)
            avg_current = 1.5 + np.random.normal(0, 0.1)
            max_voltage = 4.2 - 0.0002 * i
            min_voltage = 3.0 + 0.0001 * i
            avg_temperature = 25 + 0.01 * i + np.random.normal(0, 0.5)

            cycles.append({
                'cycle': i + 1,
                'capacity': max(0.5, capacity),  # Floor keeps capacity positive
                'avg_voltage': avg_voltage,
                'avg_current': avg_current,
                'max_voltage': max_voltage,
                'min_voltage': min_voltage,
                'avg_temperature': avg_temperature,
                'battery_id': 'synthetic'
            })

        return pd.DataFrame(cycles)

class BatteryTrainer:
    """Train a regression model with MSE loss and keep the loss history.

    Attributes:
        model: the ``nn.Module`` being trained.
        device: device every batch is moved to before the forward pass.
        train_losses: mean training batch loss, one entry per epoch.
        val_losses: mean validation batch loss, one entry per epoch (``nan``
            for epochs in which the validation loader produced no batch).
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.train_losses = []
        self.val_losses = []

    @staticmethod
    def _mean_batch_loss(total_loss, num_batches):
        """Average a summed loss over batches; ``nan`` when there were none."""
        return total_loss / num_batches if num_batches else float('nan')

    def train(self, train_loader, val_loader, num_epochs, learning_rate):
        """Run the optimisation loop.

        Uses Adam with a small weight decay and halves the learning rate when
        the monitored loss has not improved for 10 epochs. A progress line is
        printed every 20th epoch.

        Args:
            train_loader: loader yielding ``(inputs, targets)`` training batches.
            val_loader: loader yielding validation batches; may be empty.
            num_epochs: number of passes over ``train_loader``.
            learning_rate: initial Adam learning rate.
        """
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate, weight_decay=1e-5)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.5)

        for epoch in range(num_epochs):
            # Training phase
            self.model.train()
            train_loss = 0.0

            for batch_X, batch_y in train_loader:
                batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device)

                optimizer.zero_grad()
                outputs = self.model(batch_X)
                # squeeze(-1) drops only the trailing output dimension. A bare
                # squeeze() would also drop the batch dimension of a one-sample
                # batch and make the loss broadcast against the target.
                loss = criterion(outputs.squeeze(-1), batch_y)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

            # Validation phase
            self.model.eval()
            val_loss = 0.0

            with torch.no_grad():
                for batch_X, batch_y in val_loader:
                    batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device)
                    outputs = self.model(batch_X)
                    loss = criterion(outputs.squeeze(-1), batch_y)
                    val_loss += loss.item()

            num_val_batches = len(val_loader)
            avg_train_loss = self._mean_batch_loss(train_loss, len(train_loader))
            avg_val_loss = self._mean_batch_loss(val_loss, num_val_batches)

            self.train_losses.append(avg_train_loss)
            self.val_losses.append(avg_val_loss)

            # Without validation batches the scheduler follows the training
            # loss, since a nan metric would count as "no improvement" forever.
            scheduler.step(avg_val_loss if num_val_batches else avg_train_loss)

            if (epoch + 1) % 20 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}')

    def plot_training_history(self):
        """Plot the recorded training and validation loss per epoch."""
        plt.figure(figsize=(10, 6))
        plt.plot(self.train_losses, label='Training Loss')
        plt.plot(self.val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training History')
        plt.legend()
        plt.grid(True)
        plt.show()

def evaluate_model(model, test_loader, data_processor, device, sequence_info_test=None):
    """Evaluate the model and return predictions, actuals, metrics and info.

    Predictions and targets are mapped back to original units with
    ``data_processor.capacity_scaler`` before the metrics are computed.

    Args:
        model: trained model; switched to eval mode here.
        test_loader: loader yielding ``(inputs, targets)`` batches. It must
            iterate in the same order as ``sequence_info_test`` (no shuffling).
        data_processor: object exposing a fitted ``capacity_scaler``.
        device: device the batches are moved to.
        sequence_info_test: optional list with one entry per sample.

    Returns:
        ``(predictions, actuals, metrics, sequence_info)`` where ``metrics`` has
        the keys ``mse``, ``mae``, ``rmse`` and ``r2``; ``sequence_info`` is
        empty when ``sequence_info_test`` is ``None``.
    """
    model.eval()
    predictions = []
    actuals = []
    all_sequence_info = []
    # Running sample offset. It replaces ``batch_idx * loader.batch_size``,
    # which fails when the loader has no fixed batch size (batch_size is None
    # for loaders built from a batch_sampler).
    offset = 0

    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            outputs = model(batch_X)

            # Inverse transform predictions
            pred_capacity = data_processor.capacity_scaler.inverse_transform(
                outputs.cpu().numpy().reshape(-1, 1)
            ).flatten()

            actual_capacity = data_processor.capacity_scaler.inverse_transform(
                batch_y.cpu().numpy().reshape(-1, 1)
            ).flatten()

            predictions.extend(pred_capacity)
            actuals.extend(actual_capacity)

            # Store sequence information if available
            if sequence_info_test is not None:
                batch_len = len(batch_X)
                all_sequence_info.extend(sequence_info_test[offset:offset + batch_len])
                offset += batch_len

    # Calculate metrics
    mse = mean_squared_error(actuals, predictions)
    mae = mean_absolute_error(actuals, predictions)
    rmse = np.sqrt(mse)
    r2 = r2_score(actuals, predictions)

    print("Evaluation Metrics:")
    print(f"MSE: {mse:.6f}")
    print(f"MAE: {mae:.6f}")
    print(f"RMSE: {rmse:.6f}")
    print(f"R² Score: {r2:.6f}")

    return predictions, actuals, {'mse': mse, 'mae': mae, 'rmse': rmse, 'r2': r2}, all_sequence_info

def get_all_predictions(model, data_loader, data_processor, device, sequence_info=None):
    """Get predictions for all data in the loader, in original units.

    Args:
        model: trained model; switched to eval mode here.
        data_loader: loader yielding ``(inputs, targets)`` batches. It must
            iterate in the same order as ``sequence_info`` (no shuffling).
        data_processor: object exposing a fitted ``capacity_scaler``.
        device: device the batches are moved to.
        sequence_info: optional list with one entry per sample.

    Returns:
        ``(predictions, actuals, info)``; ``info`` is empty when
        ``sequence_info`` is ``None``.
    """
    model.eval()
    all_predictions = []
    all_actuals = []
    all_info = []
    # Running sample offset. It replaces ``batch_idx * loader.batch_size``,
    # which fails when the loader has no fixed batch size (batch_size is None
    # for loaders built from a batch_sampler).
    offset = 0

    with torch.no_grad():
        for batch_X, batch_y in data_loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            outputs = model(batch_X)

            # Inverse transform predictions
            pred_capacity = data_processor.capacity_scaler.inverse_transform(
                outputs.cpu().numpy().reshape(-1, 1)
            ).flatten()

            actual_capacity = data_processor.capacity_scaler.inverse_transform(
                batch_y.cpu().numpy().reshape(-1, 1)
            ).flatten()

            all_predictions.extend(pred_capacity)
            all_actuals.extend(actual_capacity)

            # Store sequence information if available
            if sequence_info is not None:
                batch_len = len(batch_X)
                all_info.extend(sequence_info[offset:offset + batch_len])
                offset += batch_len

    return all_predictions, all_actuals, all_info

def save_predictions_to_csv(predictions, actuals, sequence_info, filename='battery_predictions.csv'):
    """Write predicted vs. actual capacities to a CSV file and return the table.

    One row is produced per entry of ``sequence_info``. Besides the battery id,
    cycle and the two capacity values, each row holds the absolute error and
    the error relative to the actual capacity in percent.
    """
    # The error terms are derived once from array copies of the inputs.
    abs_error = np.abs(np.array(actuals) - np.array(predictions))
    rel_error_pct = (abs_error / np.array(actuals)) * 100

    battery_ids = []
    cycle_numbers = []
    for entry in sequence_info:
        battery_ids.append(entry['battery_id'])
        cycle_numbers.append(entry['actual_cycle'])

    columns = {
        'battery_id': battery_ids,
        'cycle': cycle_numbers,
        'actual_capacity': actuals,
        'predicted_capacity': predictions,
        'absolute_error': abs_error,
        'relative_error_percent': rel_error_pct,
    }
    results_df = pd.DataFrame(columns)

    # Persist without the index column, then report what was written.
    results_df.to_csv(filename, index=False)
    print(f"✅ Predictions saved to {filename}")
    print(f"📊 File contains {len(results_df)} predictions")

    return results_df

def plot_battery_degradation(combined_df):
    """Draw one capacity-vs-cycle curve per battery on a shared figure."""
    plt.figure(figsize=(12, 8))

    id_column = combined_df['battery_id']
    for current_id in id_column.unique():
        # Rows of this battery, ordered by cycle so the line runs left to right.
        curve = combined_df[id_column == current_id].sort_values('cycle')
        plt.plot(
            curve['cycle'],
            curve['capacity'],
            label=current_id,
            marker='o',
            markersize=2,
        )

    plt.xlabel('Cycle Number')
    plt.ylabel('Capacity (Ah)')
    plt.title('Battery Capacity Degradation - All Batteries')
    plt.legend()
    plt.grid(True)
    plt.show()


def main():
    """Run the whole pipeline: load data, train, evaluate, plot and save.

    Steps, in order:
      1. Load every ``.mat`` file from ``data_directory``; when nothing can be
         extracted, fall back to 300 cycles of synthetic data.
      2. Build sliding-window sequences and split them randomly into
         train/validation/test (70% / 15% / remainder).
      3. Train a ``BatteryLSTM`` and plot the loss history.
      4. Evaluate on the test split, then predict over the full dataset and
         write ``all_battery_predictions.csv`` to ``output_dir``.
      5. Show diagnostic plots and save the model weights together with both
         fitted scalers to ``battery_lstm_model.pth``.

    All paths are hard-coded to a Google Drive mount under ``/content/drive``.
    """
    # Configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Set output directory for CSV files
    output_dir = '/content/drive/MyDrive/DL_Project'
    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory: {output_dir}")

    # Hyperparameters
    sequence_length = 20
    hidden_size = 64
    num_layers = 2
    output_size = 1
    batch_size = 32
    num_epochs = 100
    learning_rate = 0.001
    train_ratio = 0.7
    val_ratio = 0.15

    # Initialize data processor
    processor = BatteryDataProcessor(sequence_length=sequence_length)

    # Load real battery data from your directory
    data_directory = '/content/drive/MyDrive/DL_Project/battery_data/5. Battery Data Set/1. BatteryAgingARC-FY08Q4'

    print("Loading real battery data...")
    all_battery_data = processor.load_all_battery_files(data_directory)

    # An empty dict means no file produced any cycle; synthetic data is used
    # instead so the rest of the pipeline can still run.
    if all_battery_data:
        print(f"Successfully loaded data from {len(all_battery_data)} batteries")
        combined_df = processor.combine_battery_data(all_battery_data)
        plot_battery_degradation(combined_df)
    else:
        print("No real data loaded. Generating synthetic data...")
        combined_df = processor.prepare_synthetic_data(num_cycles=300)

    print(f"Final data shape: {combined_df.shape}")
    print(f"Available columns: {combined_df.columns.tolist()}")
    print(f"Batteries: {combined_df['battery_id'].unique()}")

    # Create sequences for LSTM (now returns sequence_info as well)
    sequences, targets, sequence_info = processor.create_sequences(combined_df)
    print(f"Sequences shape: {sequences.shape}")
    print(f"Targets shape: {targets.shape}")

    if len(sequences) == 0:
        print("No sequences created. Check your data extraction.")
        return

    # Convert to PyTorch tensors
    X_tensor = torch.FloatTensor(sequences)
    y_tensor = torch.FloatTensor(targets)

    # Create dataset and split
    # The test split takes whatever is left after the two int() truncations.
    dataset = TensorDataset(X_tensor, y_tensor)
    train_size = int(train_ratio * len(dataset))
    val_size = int(val_ratio * len(dataset))
    test_size = len(dataset) - train_size - val_size

    # NOTE(review): consecutive windows of one battery overlap, so a random
    # split puts near-identical windows into train and test; the test metrics
    # are probably optimistic -- confirm whether a per-battery or chronological
    # split is wanted.
    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
    
    # Split sequence_info accordingly
    # (train_info and val_info are not referenced again in this function.)
    train_info = [sequence_info[i] for i in train_dataset.indices]
    val_info = [sequence_info[i] for i in val_dataset.indices]
    test_info = [sequence_info[i] for i in test_dataset.indices]

    # Create data loaders
    # Only the training loader shuffles; the others keep dataset order so
    # their batches line up with the sequence-info lists.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    # The feature count is the last axis of the sequence array.
    input_size = sequences.shape[2]
    model = BatteryLSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        output_size=output_size,
        dropout_rate=0.3
    ).to(device)

    print(f"Model architecture:")
    print(model)
    print(f"Input size: {input_size}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Train model
    trainer = BatteryTrainer(model, device)
    print("Starting training...")
    trainer.train(train_loader, val_loader, num_epochs, learning_rate)

    # Plot training history
    trainer.plot_training_history()

    # Evaluate model on test set
    print("Evaluating model on test set...")
    test_predictions, test_actuals, test_metrics, test_sequence_info = evaluate_model(
        model, test_loader, processor, device, test_info
    )

    # --------------------------------------------------
    # 📊 ONLY CREATE all_battery_predictions.csv
    # --------------------------------------------------
    print("\n📈 Generating predictions for entire dataset...")
    
    # Create data loader for entire dataset
    # This covers train, validation and test samples alike, so the CSV is not
    # a held-out evaluation.
    full_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    
    # Get predictions for all data
    all_predictions, all_actuals, all_info = get_all_predictions(
        model, full_loader, processor, device, sequence_info
    )
    
    # Save ONLY all predictions to CSV in the specified directory
    all_predictions_path = os.path.join(output_dir, 'all_battery_predictions.csv')
    results_df = save_predictions_to_csv(all_predictions, all_actuals, all_info, all_predictions_path)
    
    # Display first few rows of the results
    print("\n📋 First 10 rows of all predictions:")
    print(results_df.head(10))

    # --------------------------------------------------
    # 📊 ADDITIONAL PERFORMANCE VISUALIZATIONS
    # --------------------------------------------------
    # seaborn is imported here, so it is only required once this point is reached.
    import seaborn as sns

    # Use test set predictions for visualizations
    predictions = np.array(test_predictions)
    actuals = np.array(test_actuals)

    # ---- Scatter plot: Predicted vs Actual ----
    # The dashed red diagonal marks perfect predictions.
    plt.figure(figsize=(6, 6))
    sns.scatterplot(x=actuals, y=predictions, alpha=0.7)
    plt.plot([actuals.min(), actuals.max()], [actuals.min(), actuals.max()], 'r--', lw=2)
    plt.xlabel("Actual Capacity (Ah)")
    plt.ylabel("Predicted Capacity (Ah)")
    plt.title("Predicted vs Actual Capacity")
    plt.grid(True)
    plt.show()

    # ---- Line plot: First N samples ----
    # The test split is random, so the x axis is a sample index, not a cycle.
    N = 100  # You can adjust this number
    plt.figure(figsize=(10, 5))
    plt.plot(actuals[:N], label='Actual', marker='o', markersize=3)
    plt.plot(predictions[:N], label='Predicted', marker='x', markersize=3)
    plt.xlabel("Sample Index")
    plt.ylabel("Capacity (Ah)")
    plt.title(f"Actual vs Predicted Capacity (First {N} Test Samples)")
    plt.legend()
    plt.grid(True)
    plt.show()

    # ---- Residual plot ----
    residuals = actuals - predictions
    plt.figure(figsize=(8, 5))
    sns.histplot(residuals, kde=True, color='skyblue')
    plt.xlabel("Residual (Actual - Predicted)")
    plt.title("Residual Distribution")
    plt.grid(True)
    plt.show()

    # ---- Metrics Summary Table ----
    metrics_df = pd.DataFrame({
        "Metric": ["MAE", "RMSE", "R²", "MSE"],
        "Value": [test_metrics['mae'], test_metrics['rmse'], test_metrics['r2'], test_metrics['mse']]
    })

    # Render the table as a figure with the axes hidden.
    fig, ax = plt.subplots(figsize=(5, 1.5))
    ax.axis('off')
    table = ax.table(cellText=metrics_df.values, colLabels=metrics_df.columns,
                     loc='center', cellLoc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.2, 1.2)
    plt.title("Model Performance Summary", pad=10)
    plt.show()

    # --------------------------------------------------
    # 📉 Original comparison plots (optional)
    # --------------------------------------------------
    plt.figure(figsize=(15, 5))

    # Panel 1: test predictions against actual values
    plt.subplot(1, 3, 1)
    plt.plot(test_actuals, label='Actual Capacity', alpha=0.7)
    plt.plot(test_predictions, label='Predicted Capacity', alpha=0.7)
    plt.xlabel('Test Samples')
    plt.ylabel('Capacity (Ah)')
    plt.title('Predictions vs Actual (Test Set)')
    plt.legend()
    plt.grid(True)

    # Panel 2: raw capacity curve of every battery
    plt.subplot(1, 3, 2)
    for battery_id in combined_df['battery_id'].unique():
        battery_data = combined_df[combined_df['battery_id'] == battery_id].sort_values('cycle')
        plt.plot(battery_data['cycle'], battery_data['capacity'], label=battery_id, alpha=0.7)
    plt.xlabel('Cycle Number')
    plt.ylabel('Capacity (Ah)')
    plt.title('Battery Capacity Degradation')
    plt.legend()
    plt.grid(True)

    # Panel 3: loss history recorded by the trainer
    plt.subplot(1, 3, 3)
    plt.plot(trainer.train_losses, label='Training Loss')
    plt.plot(trainer.val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training History')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()

    # Save model to the output directory
    # The checkpoint stores the two scaler objects themselves, not just their
    # parameters, next to the weights.
    model_path = os.path.join(output_dir, 'battery_lstm_model.pth')
    torch.save({
        'model_state_dict': model.state_dict(),
        'processor_scaler': processor.scaler,
        'processor_capacity_scaler': processor.capacity_scaler,
        'sequence_length': sequence_length,
        'input_size': input_size
    }, model_path)

    print("✅ Model and visualizations complete!")
    print(f"📁 All files saved to: {output_dir}")
    print(f"📊 CSV files created:")
    print(f"   - all_battery_predictions.csv (all data)")  # Only this file is created


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()