# In [2]:
import sys
import torch
from torch import nn
import joblib  # For loading the scaler
from PyQt5.QtWidgets import (
    QApplication,
    QWidget,
    QVBoxLayout,
    QLabel,
    QPushButton,
    QComboBox,
    QScrollArea,
    QStackedWidget,
    QHBoxLayout,
    QGroupBox,
    QFileDialog,
    QMessageBox,
)
from PyQt5.QtGui import QPixmap, QPainter, QColor, QFont
from PyQt5.QtCore import Qt, QPoint

# Define your trained model class
class PowerConsumptionModel(nn.Module):
    """Feed-forward regressor from scaled input features to scaled targets.

    Architecture (all wrapped in ``self.network`` so checkpoint keys are
    ``network.<index>.<param>``):
    two Linear -> BatchNorm1d -> ELU stages (256, 128 units), a
    Linear(64) -> LeakyReLU(0.02) stage, then Linear(output_size) -> Softplus.
    Softplus makes every output strictly positive.

    Args:
        input_size: Number of input features.
        output_size: Number of regression targets.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        layers = []
        width = input_size
        # Normalised hidden stages; layers are created in the same order as a
        # flat listing so parameter initialisation consumes the RNG identically.
        for hidden_width in (256, 128):
            layers += [
                nn.Linear(width, hidden_width),
                nn.BatchNorm1d(hidden_width),
                nn.ELU(),
            ]
            width = hidden_width
        # Un-normalised bottleneck followed by the positive-valued output head.
        layers += [
            nn.Linear(width, 64),
            nn.LeakyReLU(negative_slope=0.02),
            nn.Linear(64, output_size),
            nn.Softplus(),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to a ``(batch, input_size)`` tensor."""
        output = self.network(x)
        return output

# Encoding Dictionaries
# Label encodings mapping each categorical input value to the integer code the
# model and scaler_X were trained on. Within every category the codes follow
# alphabetical order of the keys (presumably produced by sklearn's LabelEncoder
# during training — TODO confirm against the training code). Editing any code
# here silently changes what the model receives.
encodings = {
    # NOTE(review): "chikkamagaluru" and "chikmagalur" look like alternate
    # spellings of one district but carry different codes; kept as-is because
    # the codes must match the training data.
    "City": {
        "bagalkot": 0,
        "ballari": 1,
        "bangalore": 2,
        "belgaum": 3,
        "bidar": 4,
        "bijapur": 5,
        "chikkaballapura": 6,
        "chikkamagaluru": 7,
        "chikmagalur": 8,
        "chitradurga": 9,
        "dakshina kannada": 10,
        "davanagere": 11,
        "dharwad": 12,
        "gadag": 13,
        "hassan": 14,
        "haveri": 15,
        "hubli": 16,
        "kalaburagi": 17,
        "kodagu": 18,
        "kolar": 19,
        "koppal": 20,
        "mandya": 21,
        "mangalore": 22,
        "mysore": 23,
        "raichur": 24,
        "ramanagara": 25,
        "shivamogga": 26,
        "tumkur": 27,
        "udupi": 28,
        "yadgir": 29,
    },
    # Only one state is known, so this feature is constant.
    "State": {"karnataka": 0},
    # Key order here (alphabetical, not calendar) is also the order shown in
    # the month combo box, which is filled from encodings["Month"].keys().
    "Month": {
        "april": 0,
        "august": 1,
        "december": 2,
        "february": 3,
        "january": 4,
        "july": 5,
        "june": 6,
        "march": 7,
        "may": 8,
        "november": 9,
        "october": 10,
        "september": 11,
    },
    "Weather Condition": {
        "cloudy": 0,
        "cold": 1,
        "cool": 2,
        "hot": 3,
        "rainy": 4,
        "sunny": 5,
        "warm": 6,
    },
    # Encoding of the "Power Supply Status" target column; nothing visible in
    # this file decodes the model's status output with it.
    "Power Supply Status": {"insufficient": 0, "overflow": 1, "sufficient": 2},
}

# Feature order expected by scaler_X and the model: categorical codes, weather
# measurements, then per-appliance consumption. Rows passed to the model must
# list values in exactly this order.
input_columns = [
    "City",
    "State",
    "Month",
    "Weather Condition",
    "Temperature (°C)",
    "Humidity (%)",
    "Rainfall Chances (%)",
    "Wind Speed (km/h)",
] + [
    f"{appliance} Power Consumed (kWh)"
    for appliance in (
        "Fan",
        "Light",
        "Mixer",
        "Washing Machine",
        "Phone Charging",
        "UPS",
        "Grinder",
        "AC",
        "Heater",
        "Fridge",
        "TV",
    )
]

# Model / scaler_y output columns, in output order.
target_columns = [
    "Total Power Consumed (kWh)",
    "Required Power Supply (kW)",
    "Current Power Supply (kW)",
    "Power Supply Status",
]

# First Page: Month Selection
class MonthSelectionPage(QWidget):
    """Wizard page one: a month picker followed by a "Next" button.

    Args:
        on_month_selected: Callback that receives the combo box's current
            text (one of the ``encodings["Month"]`` keys) when "Next" is
            clicked.
    """

    def __init__(self, on_month_selected):
        super().__init__()
        self.on_month_selected = on_month_selected

        # Centred heading.
        heading = QLabel("Select Month:")
        heading.setAlignment(Qt.AlignCenter)
        heading.setStyleSheet("font-size: 18px; font-weight: bold; color: #333;")

        # Choices appear in the dict's key order.
        self.month_combo = QComboBox()
        self.month_combo.addItems(encodings["Month"].keys())
        self.month_combo.setStyleSheet(
            "padding: 8px; border: 1px solid #ccc; border-radius: 4px;"
        )

        proceed = QPushButton("Next")
        proceed.clicked.connect(self.on_next_clicked)
        proceed.setStyleSheet(
            "background-color: #4CAF50; color: white; padding: 10px; "
            "border: none; border-radius: 4px;"
        )

        column = QVBoxLayout()
        for widget in (heading, self.month_combo, proceed):
            column.addWidget(widget)
        self.setLayout(column)
        self.setStyleSheet("background-color: #f0f0f0;")

    def on_next_clicked(self):
        """Hand the currently selected month to the page's callback."""
        month = self.month_combo.currentText()
        self.on_month_selected(month)

# Second Page: Karnataka Map with District Predictions
class KarnatakaMapPage(QWidget):
    """Wizard page two: predicts every district and plots the results.

    For the selected month, a feature row is built for each city in
    ``encodings["City"]``, scaled with ``scaler_x``, run through ``model`` and
    inverse-scaled with ``scaler_y``. Each district gets three coloured
    markers on the map image and a text row in the scrollable results list.

    Args:
        model: Trained ``torch.nn.Module``; switched to eval mode before
            every prediction.
        scaler_x: Fitted feature scaler exposing ``transform``.
        scaler_y: Fitted target scaler exposing ``inverse_transform``.
        selected_month: A key of ``encodings["Month"]``.
    """

    def __init__(self, model, scaler_x, scaler_y, selected_month):
        super().__init__()
        self.model = model
        self.scaler_x = scaler_x
        self.scaler_y = scaler_y
        self.selected_month = selected_month
        # city -> list of unscaled model outputs from the latest prediction run.
        self.district_predictions = {}
        # Marker positions in the *unscaled* map image's pixel space.
        # NOTE(review): these follow a regular grid (x steps by 100 and wraps
        # after 950, y steps by 50), so they are placeholders rather than real
        # district locations; points beyond the image bounds are not visible.
        # TODO replace with positions measured on karnataka_map.png.
        self.city_coordinates = {
            "bagalkot": (50, 50),
            "ballari": (150, 100),
            "bangalore": (250, 150),
            "belgaum": (350, 200),
            "bidar": (450, 250),
            "bijapur": (550, 300),
            "chikkaballapura": (650, 350),
            "chikkamagaluru": (750, 400),
            "chikmagalur": (850, 450),
            "chitradurga": (950, 500),
            "dakshina kannada": (50, 550),
            "davanagere": (150, 600),
            "dharwad": (250, 650),
            "gadag": (350, 700),
            "hassan": (450, 750),
            "haveri": (550, 800),
            "hubli": (650, 850),
            "kalaburagi": (750, 900),
            "kodagu": (850, 950),
            "kolar": (950, 1000),
            "koppal": (50, 1050),
            "mandya": (150, 1100),
            "mangalore": (250, 1150),
            "mysore": (350, 1200),
            "raichur": (450, 1250),
            "ramanagara": (550, 1300),
            "shivamogga": (650, 1350),
            "tumkur": (750, 1400),
            "udupi": (850, 1450),
            "yadgir": (950, 1500),
        }

        self.init_ui()

    def init_ui(self):
        """Build the map, predict button, summary label, legend and results list."""
        layout = QVBoxLayout()

        # Keep the untouched map so every run draws on a fresh copy.
        self.original_pixmap = QPixmap("karnataka_map.png")
        self.image_label = QLabel()
        if self.original_pixmap.isNull():
            # QPixmap does not raise for a missing/unreadable file; it just
            # yields a null pixmap, which would leave the label blank.
            self.image_label.setText("Could not load 'karnataka_map.png'.")
        else:
            self.image_label.setPixmap(self.original_pixmap.scaledToWidth(600))
        self.image_label.setAlignment(Qt.AlignCenter)

        predict_button = QPushButton("Predict Power Consumption for All Districts")
        predict_button.clicked.connect(self.predict_all_districts)
        predict_button.setStyleSheet(
            "background-color: #007BFF; color: white; padding: 10px; border: none; border-radius: 4px;"
        )

        self.results_label = QLabel()  # Overall summary line
        self.results_label.setAlignment(Qt.AlignCenter)
        self.results_label.setStyleSheet(
            "font-size: 16px; font-weight: bold; color: #555;"
        )

        layout.addWidget(self.image_label)
        layout.addWidget(predict_button)
        layout.addWidget(self.results_label)

        self.legend_label = QLabel()
        self.create_legend()
        layout.addWidget(self.legend_label)

        # Scrollable list holding one row per district.
        scroll_area = QScrollArea()
        self.results_widget = QWidget()
        self.results_layout = QVBoxLayout()
        self.results_widget.setLayout(self.results_layout)
        scroll_area.setWidgetResizable(True)
        scroll_area.setWidget(self.results_widget)

        layout.addWidget(scroll_area)
        self.setLayout(layout)
        self.setStyleSheet("background-color: #e6f7ff;")

    def create_legend(self):
        """Render the marker colour key into ``self.legend_label``."""
        legend_pixmap = QPixmap(300, 120)
        legend_pixmap.fill(Qt.white)

        dot_size = 12
        x_offset = 15
        y_offset = 25
        line_spacing = 25
        entries = (
            ("red", "Total Power Consumed (kWh)"),
            ("blue", "Required Power Supply (kW)"),
            ("green", "Current Power Supply (kW)"),
        )

        painter = QPainter(legend_pixmap)
        try:
            painter.setFont(QFont("Arial", 10))
            for row, (color, label) in enumerate(entries):
                top = y_offset + row * line_spacing
                painter.setBrush(QColor(color))
                painter.drawEllipse(x_offset, top, dot_size, dot_size)
                # drawText(int, int, str) needs ints; dot_size / 2 is a float.
                painter.drawText(
                    int(x_offset + dot_size + 10),
                    int(top + dot_size / 2 + 4),
                    label,
                )
        finally:
            # Always release the pixmap, even if drawing raised.
            painter.end()

        self.legend_label.setPixmap(legend_pixmap)
        self.legend_label.setAlignment(Qt.AlignBottom | Qt.AlignRight)

    def _clear_results(self):
        """Remove every row from the results list and schedule its deletion."""
        while self.results_layout.count():
            item = self.results_layout.takeAt(0)
            widget = item.widget()
            if widget is not None:
                widget.deleteLater()

    def _draw_district_markers(self, painter, x, y):
        """Draw the red/blue/green marker triple for one district at (x, y)."""
        dot_size = 6
        x_offset = 10
        # Colours match the legend: total consumed, required, current supply.
        for step, color in enumerate(("red", "blue", "green")):
            painter.setBrush(QColor(color))
            painter.drawEllipse(
                int(x + x_offset + 10 * step), int(y), dot_size, dot_size
            )

    def predict_all_districts(self):
        """Predict every district for the selected month and refresh the page.

        Replaces the results list, repopulates ``district_predictions``,
        redraws markers on a copy of the map (when the map loaded) and shows
        the average predicted total consumption over processed districts.
        """
        self._clear_results()
        self.district_predictions = {}

        self.current_pixmap = self.original_pixmap.copy()
        can_draw = not self.current_pixmap.isNull()
        painter = None
        if can_draw:
            painter = QPainter(self.current_pixmap)
            painter.setFont(QFont("Arial", 8))

        total_power_consumed = 0
        district_count = 0
        try:
            for city in encodings["City"]:
                input_data = self.create_input_data(city, self.selected_month)
                prediction = self.predict_power_consumption(input_data)
                if not prediction:
                    continue

                self.district_predictions[city] = prediction
                total_power_consumed += prediction[0]
                district_count += 1

                if painter is not None:
                    x, y = self.city_coordinates.get(city, (0, 0))
                    self._draw_district_markers(painter, x, y)

                status = self.get_power_supply_status(prediction)
                self.results_layout.addWidget(
                    QLabel(
                        f"<b>{city.title()}:</b> "
                        f"Total Power Consumed: {prediction[0]:.2f} kWh, "
                        f"Required Power: {prediction[1]:.2f} kW, "
                        f"Current Power: {prediction[2]:.2f} kW, "
                        f"Status: {status}"
                    )
                )
        finally:
            # An active QPainter must be ended before the pixmap is reused.
            if painter is not None:
                painter.end()

        if district_count > 0:
            average_power_consumed = total_power_consumed / district_count
            self.results_label.setText(
                f"<b>Overall Average Power Consumption:</b> {average_power_consumed:.2f} kWh"
            )
        else:
            self.results_label.setText("No districts processed.")

        if can_draw:
            self.image_label.setPixmap(self.current_pixmap.scaledToWidth(600))

    def create_input_data(self, city, selected_month):
        """Build one raw (unscaled) feature row for ``city`` in ``selected_month``.

        Only the city code, month code, weather condition, temperature and
        humidity vary; all other features are fixed example values shared by
        every district.

        Args:
            city: A key of ``encodings["City"]``.
            selected_month: A key of ``encodings["Month"]``.

        Returns:
            list: 19 values ordered like ``input_columns``.

        Raises:
            KeyError: If ``city`` or ``selected_month`` has no encoding.
        """
        state = "karnataka"
        # Example per-district weather; everything else defaults to "sunny".
        if city in ["bangalore", "mysore"]:
            weather_condition = "cool"
        elif city in ["ballari", "raichur"]:
            weather_condition = "hot"
        else:
            weather_condition = "sunny"

        # Example per-district temperature (°C) and humidity (%).
        if city == "bangalore":
            temperature = 25.0
            humidity = 70.0
        elif city == "mysore":
            temperature = 27.0
            humidity = 75.0
        else:
            temperature = 30.0
            humidity = 60.0

        # Fixed example values (kWh for appliances).
        rainfall_chances = 5.0
        wind_speed = 10.0
        fan_power = 0.1
        light_power = 0.05
        mixer_power = 0.2
        washing_machine_power = 0.5
        phone_charging_power = 0.01
        ups_power = 0.3
        grinder_power = 0.25
        ac_power = 1.5
        heater_power = 0.0
        fridge_power = 0.4
        tv_power = 0.2

        return [
            encodings["City"][city],
            encodings["State"][state],
            encodings["Month"][selected_month],
            encodings["Weather Condition"][weather_condition],
            temperature,
            humidity,
            rainfall_chances,
            wind_speed,
            fan_power,
            light_power,
            mixer_power,
            washing_machine_power,
            phone_charging_power,
            ups_power,
            grinder_power,
            ac_power,
            heater_power,
            fridge_power,
            tv_power,
        ]

    def predict_power_consumption(self, input_data):
        """Run the model on one raw feature row and return unscaled outputs.

        Args:
            input_data: Raw feature values ordered like ``input_columns``.

        Returns:
            list[float]: One value per model output, presumably in
            ``target_columns`` order (total kWh, required kW, current kW,
            encoded status) — confirm against how scaler_y was fitted.
        """
        scaled_data = self.scaler_x.transform([input_data])
        input_tensor = torch.tensor(scaled_data, dtype=torch.float32)

        # eval() makes BatchNorm use its running statistics; in training mode
        # a single-row batch would raise.
        self.model.eval()
        with torch.no_grad():
            prediction = self.model(input_tensor)

        inverse_prediction = self.scaler_y.inverse_transform(prediction.numpy())
        # Take the single row explicitly: squeeze() would turn a one-output
        # model's result into a bare float instead of a list.
        return inverse_prediction[0].tolist()

    def get_power_supply_status(self, predictions):
        """Classify supply from predicted required vs. current power.

        Args:
            predictions: Sequence where index 1 is required power and index 2
                is current power supply.

        Returns:
            str: "Underflow" if required > current, "Overflow" if
            required < current, otherwise "Sufficient". "Sufficient" needs
            exact float equality, so it is rare for model outputs.
        """
        required_power = predictions[1]
        current_power = predictions[2]
        if required_power > current_power:
            return "Underflow"
        elif required_power < current_power:
            return "Overflow"
        else:
            return "Sufficient"

# Main Window
class MainWindow(QWidget):
    """Top-level window stacking the month-selection and map pages.

    Args:
        model: Trained model passed through to the map page.
        scaler_x: Fitted feature scaler passed through to the map page.
        scaler_y: Fitted target scaler passed through to the map page.
    """

    def __init__(self, model, scaler_x, scaler_y):
        super().__init__()
        self.model = model
        self.scaler_x = scaler_x
        self.scaler_y = scaler_y
        self.stacked_widget = QStackedWidget()

        self.month_selection_page = MonthSelectionPage(self.show_karnataka_map_page)
        # The map page needs the chosen month, so it is created lazily in
        # show_karnataka_map_page.
        self.karnataka_map_page = None
        self.stacked_widget.addWidget(self.month_selection_page)

        layout = QVBoxLayout()
        layout.addWidget(self.stacked_widget)
        self.setLayout(layout)
        self.setWindowTitle("Power Consumption Predictor")
        self.setStyleSheet("background-color: #f5f5dc;")

    def show_karnataka_map_page(self, selected_month):
        """Create a map page for ``selected_month`` and switch to it.

        Any map page from a previous call is removed and scheduled for
        deletion, so repeated selections do not pile pages into the stack.
        """
        if self.karnataka_map_page is not None:
            self.stacked_widget.removeWidget(self.karnataka_map_page)
            self.karnataka_map_page.deleteLater()

        self.karnataka_map_page = KarnatakaMapPage(
            self.model, self.scaler_x, self.scaler_y, selected_month
        )
        self.stacked_widget.addWidget(self.karnataka_map_page)
        self.stacked_widget.setCurrentWidget(self.karnataka_map_page)

if __name__ == "__main__":
    # Reuse an existing QApplication (e.g. when this cell is re-run in
    # Jupyter); Qt supports only one QApplication per process.
    app = QApplication.instance() or QApplication(sys.argv)

    input_size = len(input_columns)
    # All four target columns, including the encoded "Power Supply Status",
    # are model outputs; scaler_y is presumably fitted on the same four.
    output_size = len(target_columns)

    model = PowerConsumptionModel(input_size, output_size)
    try:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # machine; weights_only avoids unpickling arbitrary objects.
        state_dict = torch.load(
            "best_model.pth", map_location="cpu", weights_only=True
        )
    except FileNotFoundError:
        QMessageBox.critical(
            None,
            "Model Not Found",
            "Could not find 'best_model.pth'.  Please ensure it is in the correct directory.",
        )
        sys.exit(1)
    try:
        model.load_state_dict(state_dict)
    except RuntimeError as err:
        # Raised when the checkpoint's keys or tensor shapes do not match
        # this architecture (e.g. a different output_size).
        QMessageBox.critical(
            None,
            "Model Load Failed",
            f"'best_model.pth' does not match the model architecture:\n{err}",
        )
        sys.exit(1)

    # NOTE: joblib.load unpickles its input; only load trusted scaler files.
    try:
        scaler_x = joblib.load("scaler_X.pkl")
        scaler_y = joblib.load("scaler_y.pkl")
    except FileNotFoundError:
        QMessageBox.critical(
            None,
            "Scaler Not Found",
            "Could not find 'scaler_X.pkl' or 'scaler_y.pkl'.  Please ensure they are in the correct directory.",
        )
        sys.exit(1)

    # MainWindow sets its own title.
    window = MainWindow(model, scaler_x, scaler_y)
    window.setGeometry(100, 100, 800, 600)
    window.show()

    # In a notebook this raises SystemExit when the window closes.
    sys.exit(app.exec_())


# --- Captured notebook output from running the cell above (not code) ---
#   model.load_state_dict(torch.load("best_model.pth"))
#
# SystemExit: 0
#
#   warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
