# In [None]:
import sys
import os
import requests
import torch
from torch import nn
import joblib
from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QWidget, QVBoxLayout, QLabel, QHBoxLayout,
    QFrame, QScrollArea, QGridLayout, QMessageBox
)
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtGui import QPixmap, QPainter, QColor, QPen, QBrush
from sklearn.preprocessing import LabelEncoder
import numpy as np

# Utility function to get the script directory
def get_script_dir():
    """Return the directory that holds the running program.

    When the program is frozen into a bundle (e.g. by PyInstaller, which
    sets ``sys.frozen``), this is the directory of the bundled executable.
    Otherwise it is the directory of the script named by ``sys.argv[0]``,
    resolved to an absolute path first.
    """
    is_bundled = getattr(sys, 'frozen', False)
    entry_point = sys.executable if is_bundled else os.path.abspath(sys.argv[0])
    return os.path.dirname(entry_point)

# Define your trained model class (as before)
class PowerConsumptionModel(nn.Module):
    """Feed-forward regressor from a feature vector to power figures.

    Three hidden layers (256 -> 128 -> 64 units) feed a Softplus output, so
    predictions are non-negative. The first two hidden layers use
    BatchNorm1d + ELU; the third uses LeakyReLU with slope 0.02.

    All layers live in one ``nn.Sequential`` stored as ``network``. Keep that
    attribute name and the layer order: saved checkpoints such as
    ``best_model.pth`` are keyed ``network.<index>.<param>``.

    Args:
        input_size: Number of input features per sample.
        output_size: Number of regression targets per sample.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        layers = [
            # Hidden block 1
            nn.Linear(input_size, 256),
            nn.BatchNorm1d(256),
            nn.ELU(),
            # Hidden block 2
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ELU(),
            # Hidden block 3 (no normalisation)
            nn.Linear(128, 64),
            nn.LeakyReLU(negative_slope=0.02),
            # Output head; Softplus keeps the outputs non-negative
            nn.Linear(64, output_size),
            nn.Softplus(),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(batch, input_size)`` tensor to ``(batch, output_size)``."""
        return self.network(x)

# Worker Thread for Predictions
class PredictionWorker(QThread):
    """Background thread that fetches weather data and runs power predictions.

    For every district it queries the OpenWeatherMap current-weather
    endpoint, builds the 19-column feature vector the model expects, scales
    it with ``scaler_x``, runs the model and inverse-scales the output with
    ``scaler_y``. A district whose weather lookup or prediction fails is
    skipped, and the error is printed. When the loop completes, the collected
    results are emitted on :attr:`finished`.

    NOTE(review): this ``finished`` signal shadows QThread's built-in,
    argument-less ``finished()`` signal. A distinct name (e.g.
    ``results_ready``) would be clearer. The name is kept because callers
    connect to ``finished``.
    """

    # Payload: {district: {'weather': {...}, 'predictions': {...}}}
    finished = pyqtSignal(dict)

    # Month names in calendar order. They turn a month number (1-12) into a
    # key of ``month_encodings``. In the table PowerDashboard passes in, the
    # values are alphabetical label codes (e.g. 'january' -> 4), so the month
    # number itself is not the encoding.
    _MONTH_NAMES = (
        'january', 'february', 'march', 'april', 'may', 'june',
        'july', 'august', 'september', 'october', 'november', 'december',
    )

    # OpenWeatherMap reports a condition *group* in ``weather[0].main``
    # ('Clear', 'Clouds', 'Rain', ...). None of these lower-case to one of the
    # model's training categories (cloudy/cold/cool/hot/rainy/sunny/warm), so
    # they must be translated before encoding. Each documented group maps to
    # its closest training category.
    _OWM_CONDITION_MAP = {
        'clear': 'sunny',
        'clouds': 'cloudy',
        'rain': 'rainy',
        'drizzle': 'rainy',
        'thunderstorm': 'rainy',
        'snow': 'cold',
        # "Atmosphere" group (reduced visibility): treated as overcast.
        'mist': 'cloudy',
        'smoke': 'cloudy',
        'haze': 'cloudy',
        'dust': 'cloudy',
        'fog': 'cloudy',
        'sand': 'cloudy',
        'ash': 'cloudy',
        'squall': 'cloudy',
        'tornado': 'cloudy',
    }

    def __init__(self, districts, model, scaler_x, scaler_y, api_key, city_encodings, state_encodings, month_encodings, weather_encodings, month=1):
        """Store everything the worker needs; nothing runs until ``start()``.

        Args:
            districts: District names to process, in order. Each must be a key
                of ``city_encodings`` and is also used in the weather query.
            model: The torch model that produces the predictions.
            scaler_x: Fitted scaler applied to the feature vector.
            scaler_y: Fitted scaler whose ``inverse_transform`` is applied to
                the model output.
            api_key: OpenWeatherMap API key.
            city_encodings, state_encodings, month_encodings,
            weather_encodings: Label-encoding tables, name -> integer code.
            month: Calendar month number (1 = January ... 12 = December) used
                for every prediction. Defaults to 1.
        """
        super().__init__()
        self.districts = districts
        self.model = model
        self.scaler_x = scaler_x
        self.scaler_y = scaler_y
        self.api_key = api_key
        self.city_encodings = city_encodings
        self.state_encodings = state_encodings
        self.month_encodings = month_encodings
        self.weather_encodings = weather_encodings
        self.month = month

    def get_weather_data(self, district):
        """Fetch the current weather for ``district`` from OpenWeatherMap.

        Returns:
            A dict with ``temperature`` (deg C, because ``units=metric`` is
            requested), ``humidity``, ``weather`` (the condition group from
            ``weather[0].main``) and ``wind_speed`` (m/s under metric units).
            Returns ``None`` on a network/HTTP error or an unexpected payload;
            the error is printed.
        """
        try:
            url = "http://api.openweathermap.org/data/2.5/weather"
            params = {
                'q': f"{district},Karnataka,IN",
                'appid': self.api_key,
                'units': 'metric'
            }
            response = requests.get(url, params=params, timeout=10)
            response.raise_for_status()  # Raise HTTPError for bad responses
            data = response.json()
            return {
                'temperature': data['main']['temp'],
                'humidity': data['main']['humidity'],
                'weather': data['weather'][0]['main'],
                'wind_speed': data['wind']['speed']
            }
        except requests.exceptions.RequestException as e:
            print(f"Weather API error for {district}: {e}")
            return None
        except (KeyError, ValueError) as e:
            print(f"Error processing weather data for {district}: {e}")
            return None

    def _encode_month(self, month):
        """Return the label code for calendar month number ``month`` (1-12)."""
        if not 1 <= month <= 12:
            raise ValueError(f"month must be in 1..12, got {month!r}")
        return self.month_encodings[self._MONTH_NAMES[month - 1]]

    def _encode_weather(self, condition):
        """Return the label code for a weather condition string.

        Accepts either a training category (e.g. 'sunny') or an
        OpenWeatherMap condition group (e.g. 'Clear'), case-insensitively.
        Raises KeyError when neither matches.
        """
        key = condition.lower()
        if key not in self.weather_encodings:
            key = self._OWM_CONDITION_MAP.get(key, key)
        return self.weather_encodings[key]

    def predict_power(self, weather_data, month, district):
        """Predict power figures for one district.

        Args:
            weather_data: Dict as returned by :meth:`get_weather_data`.
            month: Calendar month number, 1 (January) to 12 (December).
            district: A key of ``city_encodings``.

        Returns:
            ``{'total_power', 'required_supply', 'current_supply'}`` as floats,
            taken from the first three inverse-scaled model outputs. Returns
            ``None`` if any step fails; the error is printed.
        """
        try:
            condition = weather_data['weather']
            temp = weather_data['temperature']
            input_features = np.zeros((1, 19))

            # Categorical features, label-encoded like the training data
            input_features[0, 0] = self.city_encodings[district]
            input_features[0, 1] = self.state_encodings['karnataka']  # Hardcoded
            input_features[0, 2] = self._encode_month(month)
            input_features[0, 3] = self._encode_weather(condition)

            # Weather measurements
            input_features[0, 4] = temp
            input_features[0, 5] = weather_data['humidity']
            # NOTE(review): fixed 30/10 stand-in chosen by whether the
            # condition mentions rain; confirm what this column represents.
            input_features[0, 6] = 30 if 'rain' in condition.lower() else 10
            input_features[0, 7] = weather_data['wind_speed'] * 3.6  # m/s -> km/h

            # Appliance power consumption (fixed heuristic values)
            input_features[0, 8] = 0.5 if temp > 25 else 0.2  # Fan
            input_features[0, 9] = 0.8 if month in (6, 7, 8) else 1.2  # Light (June-August lower)
            input_features[0, 10] = 0.2  # Mixer
            input_features[0, 11] = 0.5  # Washing Machine
            input_features[0, 12] = 0.1  # Phone
            input_features[0, 13] = 0.3  # UPS
            input_features[0, 14] = 0.4  # Grinder
            input_features[0, 15] = 2.0 if temp > 30 else 0.5  # AC
            input_features[0, 16] = 1.5 if temp < 20 else 0.0  # Heater
            input_features[0, 17] = 1.5  # Fridge
            input_features[0, 18] = 0.3  # TV

            # Scale and predict
            input_scaled = self.scaler_x.transform(input_features)
            input_tensor = torch.FloatTensor(input_scaled)

            with torch.no_grad():
                output = self.model(input_tensor)

            predictions = self.scaler_y.inverse_transform(output.numpy())

            return {
                'total_power': float(predictions[0, 0]),
                'required_supply': float(predictions[0, 1]),
                'current_supply': float(predictions[0, 2])
            }
        except Exception as e:
            print(f"Prediction error for {district}: {e}")
            return None

    def run(self):
        """Process every district, then emit the collected results.

        Stops early without emitting if ``requestInterruption()`` was called,
        e.g. because the receiving window is closing.
        """
        # BatchNorm layers refuse a single-sample batch in training mode. Force
        # inference mode even if the caller never called eval() (for example,
        # because loading the checkpoint failed).
        self.model.eval()
        all_results = {}
        for district in self.districts:
            if self.isInterruptionRequested():
                return
            weather_data = self.get_weather_data(district)
            if not weather_data:
                continue
            predictions = self.predict_power(weather_data, self.month, district)
            if predictions:
                all_results[district] = {
                    'weather': weather_data,
                    'predictions': predictions
                }
        self.finished.emit(all_results)  # Send results back to main thread

class PowerDashboard(QMainWindow):
    """Main window: Karnataka map with indicators plus per-district stat cards.

    Construction builds the UI and immediately starts a PredictionWorker over
    every district in the city encoding table. When the worker's ``finished``
    signal delivers the results, :meth:`update_dashboard` redraws the map and
    the cards.
    """

    # Intended indicator positions on Karnataka_map.png, in pixels of the
    # unscaled image (indicators are painted before the map is scaled).
    DISTRICT_COORDINATES = {
        "Bagalkot": (350, 150),
        "Ballari": (400, 250),
        "Belagavi": (250, 120),
        "Bengaluru Rural": (450, 400),
        "Bengaluru Urban": (470, 420),
        "Bidar": (500, 100),
        "Chamarajanagar": (350, 500),
        "Chikkaballapur": (450, 350),
        "Chikkamagaluru": (300, 350),
        "Chitradurga": (350, 300),
        "Dakshina Kannada": (200, 400),
        "Davanagere": (300, 300),
        "Dharwad": (280, 150),
        "Gadag": (300, 180),
        "Hassan": (280, 400),
        "Haveri": (270, 200),
        "Kalaburagi": (480, 150),
        "Kodagu": (250, 450),
        "Kolar": (500, 400),
        "Koppal": (380, 200),
        "Mandya": (350, 450),
        "Mysuru": (320, 470),
        "Raichur": (450, 200),
        "Ramanagara": (400, 420),
        "Shivamogga": (250, 300),
        "Tumakuru": (380, 350),
        "Udupi": (200, 350),
        "Uttara Kannada": (220, 200),
        "Vijayapura": (320, 100),
        "Yadgir": (470, 180),
        "Vijayanagara": (380, 230),
    }

    # Case-insensitive view of DISTRICT_COORDINATES.
    _COORDINATES_BY_KEY = {name.lower(): xy for name, xy in DISTRICT_COORDINATES.items()}

    # The city encoding table uses older spellings and some city names that
    # differ from the map's district names. Map each onto the district it
    # belongs to. Names not listed here are matched case-insensitively.
    # NOTE: pairs such as hubli/dharwad share a position, so their
    # indicators overlap.
    _MAP_ALIASES = {
        'bangalore': 'bengaluru urban',
        'belgaum': 'belagavi',
        'bijapur': 'vijayapura',
        'chikkaballapura': 'chikkaballapur',
        'chikmagalur': 'chikkamagaluru',
        'hubli': 'dharwad',
        'mangalore': 'dakshina kannada',
        'mysore': 'mysuru',
        'tumkur': 'tumakuru',
    }

    def __init__(self, model, input_size, scaler_x, scaler_y):
        super().__init__()
        self.model = model
        self.input_size = input_size
        self.scaler_x = scaler_x
        self.scaler_y = scaler_y
        # Read the key from the environment so it need not live in source;
        # the placeholder is kept as the fallback.
        self.API_KEY = os.environ.get("OPENWEATHER_API_KEY", "YOUR_API_KEY")  # Replace
        self.encodings = {
            'City': {'bagalkot': 0, 'ballari': 1, 'bangalore': 2, 'belgaum': 3,
                     'bidar': 4, 'bijapur': 5, 'chikkaballapura': 6,
                     'chikkamagaluru': 7, 'chikmagalur': 8, 'chitradurga': 9,
                     'dakshina kannada': 10, 'davanagere': 11, 'dharwad': 12,
                     'gadag': 13, 'hassan': 14, 'haveri': 15, 'hubli': 16,
                     'kalaburagi': 17, 'kodagu': 18, 'kolar': 19, 'koppal': 20,
                     'mandya': 21, 'mangalore': 22, 'mysore': 23, 'raichur': 24,
                     'ramanagara': 25, 'shivamogga': 26, 'tumkur': 27,
                     'udupi': 28, 'yadgir': 29},
            'State': {'karnataka': 0},
            'Month': {'april': 0, 'august': 1, 'december': 2, 'february': 3,
                      'january': 4, 'july': 5, 'june': 6, 'march': 7, 'may': 8,
                      'november': 9, 'october': 10, 'september': 11},
            'Weather Condition': {'cloudy': 0, 'cold': 1, 'cool': 2, 'hot': 3,
                                  'rainy': 4, 'sunny': 5, 'warm': 6},
            'Power Supply Status': {'insufficient': 0, 'overflow': 1, 'sufficient': 2},
        }
        # Encoders (Initialize and fit as in your original code)
        self.city_encodings = self.encodings['City']
        self.state_encodings = self.encodings['State']
        self.month_encodings = self.encodings['Month']
        self.weather_encodings = self.encodings['Weather Condition']
        self.districts = list(self.city_encodings.keys())
        self.prediction_worker = None  # set by start_predictions()
        self.init_ui()
        self.start_predictions()

    def init_ui(self):
        """Build the map panel and the scrollable statistics panel."""
        self.setWindowTitle("Karnataka Power Dashboard")
        self.setGeometry(100, 100, 1400, 900)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        # Map and statistics side by side. The 1400px-wide window has room for
        # the 700px map next to the stats panel, while stacking them left the
        # stats squeezed under an 800px-tall map.
        main_layout = QHBoxLayout(central_widget)

        # Map Section
        map_frame = QFrame()
        map_layout = QVBoxLayout(map_frame)
        self.map_label = QLabel()
        map_path = os.path.join(get_script_dir(), "Karnataka_map.png")
        self.map_pixmap = QPixmap(map_path)
        if self.map_pixmap.isNull():
            print(f"Could not load map image: {map_path}")
        else:
            self.map_label.setPixmap(
                self.map_pixmap.scaled(700, 800, Qt.KeepAspectRatio)
            )
        map_layout.addWidget(self.map_label)
        main_layout.addWidget(map_frame)

        # Stats Section
        stats_frame = QFrame()
        stats_layout = QVBoxLayout(stats_frame)  # parent arg installs the layout
        stats_layout.addWidget(QLabel("District Statistics"))
        self.stats_scroll = QScrollArea()
        self.stats_content = QWidget()
        self.stats_grid = QGridLayout(self.stats_content)
        self.stats_scroll.setWidget(self.stats_content)
        self.stats_scroll.setWidgetResizable(True)
        stats_layout.addWidget(self.stats_scroll)
        main_layout.addWidget(stats_frame, 1)  # stats take the spare width

    def start_predictions(self):
        """Start the background worker that fetches weather and predicts."""
        self.prediction_worker = PredictionWorker(
            self.districts, self.model, self.scaler_x, self.scaler_y, self.API_KEY,
            self.city_encodings, self.state_encodings, self.month_encodings, self.weather_encodings
        )
        self.prediction_worker.finished.connect(self.update_dashboard)
        self.prediction_worker.start()

    def update_dashboard(self, all_results):
        """Replace the stat cards and redraw map indicators from ``all_results``.

        ``all_results`` maps district -> {'weather': ..., 'predictions': ...}
        (the PredictionWorker payload). Districts missing from it are skipped.
        """
        # Remove the cards from the previous run.
        while self.stats_grid.count():
            widget = self.stats_grid.takeAt(0).widget()
            if widget is not None:
                widget.deleteLater()

        if not all_results and self.isVisible():
            QMessageBox.warning(
                self, "No data",
                "No district predictions are available. Check the "
                "OpenWeatherMap API key and network connection; details "
                "are printed to the console.")

        # Paint on a copy of the pristine map so indicators from earlier runs
        # do not accumulate. Painting a null pixmap (missing file) only
        # produces Qt warnings, so skip the map in that case.
        can_draw = not self.map_pixmap.isNull()
        result_pixmap = self.map_pixmap.copy()
        painter = QPainter(result_pixmap) if can_draw else None
        try:
            if painter is not None:
                painter.setRenderHint(QPainter.Antialiasing)
            # Keep table order but place cards contiguously, without gaps for
            # districts whose prediction failed.
            available = [d for d in self.districts if d in all_results]
            for slot, district in enumerate(available):
                weather_data = all_results[district]['weather']
                predictions = all_results[district]['predictions']
                self.add_district_stats(slot, district, predictions, weather_data)
                if painter is not None:
                    self.draw_district_indicators(painter, district, predictions)
        finally:
            if painter is not None:
                painter.end()

        if can_draw:
            self.map_label.setPixmap(result_pixmap.scaled(700, 800, Qt.KeepAspectRatio))

    def add_district_stats(self, idx, district, predictions, weather_data):
        """Add a stat card for ``district`` at grid slot ``idx`` (two columns)."""
        district_frame = QFrame()
        district_frame.setStyleSheet("""
          QFrame {
              background-color: #E3F2FD;
              padding: 15px;
              margin: 8px;
          }
          QLabel {
              font-size: 12px;
              color: #1A237E;
          }
      """)
        layout = QVBoxLayout(district_frame)

        name_label = QLabel(f"<b>{district}</b>")
        name_label.setStyleSheet("font-size: 14px; color: #1565C0;")
        layout.addWidget(name_label)

        # The <b> tags make QLabel render this as rich text, where plain
        # newlines collapse into spaces; <br> is needed for one item per line.
        stats_lines = [
            f"<b>Weather:</b> {weather_data['weather']}",
            f"<b>Temperature:</b> {weather_data['temperature']:.1f}°C",
            f"<b>Humidity:</b> {weather_data['humidity']}%",
            f"<b>Total Power:</b> {predictions['total_power']:.2f} kWh",
            f"<b>Required Supply:</b> {predictions['required_supply']:.2f} kW",
            f"<b>Current Supply:</b> {predictions['current_supply']:.2f} kW",
        ]
        layout.addWidget(QLabel("<br>".join(stats_lines)))

        self.stats_grid.addWidget(district_frame, idx // 2, idx % 2)

    def _map_position(self, district):
        """Return the (x, y) map position for ``district``, or None if unknown."""
        key = district.lower()
        key = self._MAP_ALIASES.get(key, key)
        return self._COORDINATES_BY_KEY.get(key)

    def draw_district_indicators(self, painter, district, predictions):
        """Draw three coloured dots for ``district`` with an active ``painter``.

        The dots show total power (red), required supply (green) and current
        supply (blue), placed 15px apart. A value above 100 gets an extra
        red outline ring. Unknown districts are silently skipped.
        """
        position = self._map_position(district)
        if position is None:
            return
        x, y = position

        indicators = [
            ("total_power", QColor(255, 50, 50, 200)),  # Semi-transparent red
            ("required_supply", QColor(50, 255, 50, 200)),  # Semi-transparent green
            ("current_supply", QColor(50, 50, 255, 200)),  # Semi-transparent blue
        ]

        for i, (key, color) in enumerate(indicators):
            indicator_x = x + i * 15
            indicator_y = y

            painter.setPen(QPen(color, 3))
            painter.setBrush(QBrush(color))
            painter.drawEllipse(indicator_x, indicator_y, 10, 10)

            if predictions[key] > 100:
                # Outline only: the indicator's fill would otherwise cover
                # the dot with a second, larger disc.
                painter.setPen(QPen(Qt.red, 2))
                painter.setBrush(Qt.NoBrush)
                painter.drawEllipse(indicator_x - 2, indicator_y - 2, 14, 14)

    def closeEvent(self, event):
        """Stop the prediction thread before the window goes away.

        Qt documents that deleting a still-running QThread crashes the
        program, so ask the worker to stop and wait for it. NOTE: an
        in-flight HTTP request can make this wait up to its timeout.
        """
        worker = self.prediction_worker
        if worker is not None and worker.isRunning():
            worker.requestInterruption()
            worker.wait()
        super().closeEvent(event)

if __name__ == "__main__":
    input_size = 19  # Number of features PredictionWorker.predict_power builds
    # Must match the trained checkpoint. Only the first three outputs
    # (total_power, required_supply, current_supply) are displayed.
    # NOTE(review): confirm what the fourth output column represents.
    output_size = 4

    script_dir = get_script_dir()

    def _resolve(filename):
        """Prefer ``filename`` next to the script; fall back to the CWD."""
        candidate = os.path.join(script_dir, filename)
        return candidate if os.path.exists(candidate) else filename

    # Load the model. The dashboard is useless with random weights, so exit
    # on failure instead of continuing. map_location lets a GPU-trained
    # checkpoint load on a CPU-only machine.
    # NOTE: torch.load unpickles data (unless weights_only is in effect);
    # only load trusted checkpoint files.
    model = PowerConsumptionModel(input_size, output_size)
    try:
        state_dict = torch.load(_resolve("best_model.pth"), map_location="cpu")
        model.load_state_dict(state_dict)
    except Exception as e:
        sys.exit(f"Failed to load the model : {e}")
    model.eval()  # Set the model to evaluation mode
    print("Model loaded successfully.")

    # Load the scalers. Both are required to build the dashboard, so exit on
    # failure rather than crash later with an undefined name.
    # NOTE: joblib.load unpickles arbitrary objects; only load trusted files.
    try:
        scaler_x = joblib.load(_resolve('scaler_X.pkl'))
        scaler_y = joblib.load(_resolve('scaler_y.pkl'))
    except Exception as e:
        sys.exit(f"Failed to load the Scaler: {e}")
    print("Model and Scalers has been loaded.")

    app = QApplication(sys.argv)
    dashboard = PowerDashboard(model, input_size, scaler_x, scaler_y)
    dashboard.show()
    sys.exit(app.exec_())
