# In [None]:
import sys
import os
import sqlite3
import json
import numpy as np
import threading
import matplotlib
matplotlib.use('Qt5Agg')  # Ensure compatibility with PyQt5
import matplotlib.pyplot as plt
from PyQt5.QtWidgets import QApplication, QWidget, QLabel, QLineEdit, QVBoxLayout, QHBoxLayout, QPushButton, QComboBox
from PyQt5.QtGui import QPixmap
from PyQt5.QtCore import Qt, QThread, pyqtSignal
import torch
# Diagnostic only: report whether PyTorch can see a CUDA device.
print("CUDA Available:", torch.cuda.is_available())

from ultralytics import YOLO

# The YOLO constructor accepts only (model, task, verbose); passing ``device``
# raises TypeError as soon as the module is imported. Load the weights first
# and then move the model to the CPU explicitly.
model = YOLO("yolov8n.pt")
model.to("cpu")  # Force CPU mode



class HeatmapGUI(QWidget):
    """Window that plots Bx/Bz heatmaps from a SQLite table and runs YOLO on them.

    The widget lists the ``*.db`` files found in the working directory, lets
    the user enter a sample range, renders the normalized magnitudes of that
    range to ``heatmap.png`` and can hand the image to a YOLOv8 model.
    """

    HEATMAP_FILENAME = "heatmap.png"
    DETECTION_FILENAME = "yolo_heatmap.png"

    # Row layout of the transposed ``data`` table as used by the slicing
    # below: rows 1-14 are the imaginary parts, rows 15-28 the real parts.
    # Inside each 14-row group the first 8 rows are Bz and the next 6 are Bx.
    # Row 0 is not used (presumably an index/timestamp column - TODO confirm).
    IMAG_ROWS = slice(1, 15)
    REAL_ROWS = slice(15, 29)
    BZ_ROWS = slice(0, 8)
    BX_ROWS = slice(8, 14)

    # Column of the selected range that serves as normalization reference.
    # NOTE(review): index 1 is the *second* sample of the range; if this was
    # ported from 1-based MATLAB code, index 0 may have been intended.
    REFERENCE_COLUMN = 1

    def __init__(self):
        """Discover the available databases and build the widgets."""
        super().__init__()
        self.local_path = os.getcwd()
        self.available_db = self.load_db_files()
        self.selected_db = self.available_db[0] if self.available_db else None
        self.heatmap_path = None
        self.yolo_model = None  # Loaded lazily on the first detection run
        # Named with a prefix so they do not shadow QObject.thread().
        self._train_thread = None
        self._train_worker = None
        self._training = False
        self.initUI()

    def initUI(self):
        """Create the range inputs, database selector, buttons and image area."""
        main_layout = QVBoxLayout()
        input_layout = QHBoxLayout()

        self.label_low = QLabel('Enter lower range:')
        self.input_low = QLineEdit()
        input_layout.addWidget(self.label_low)
        input_layout.addWidget(self.input_low)

        self.label_high = QLabel('Enter higher range:')
        self.input_high = QLineEdit()
        input_layout.addWidget(self.label_high)
        input_layout.addWidget(self.input_high)
        main_layout.addLayout(input_layout)

        self.dataset_selector = QComboBox()
        self.dataset_selector.addItems(self.available_db)
        self.dataset_selector.currentIndexChanged.connect(self.change_selected_db)
        main_layout.addWidget(self.dataset_selector)

        self.btn_generate = QPushButton('Generate Heatmap')
        self.btn_generate.clicked.connect(self.generate_heatmap)
        main_layout.addWidget(self.btn_generate)

        self.btn_train = QPushButton('Train YOLO Model')
        self.btn_train.clicked.connect(self.train_yolo_model)
        main_layout.addWidget(self.btn_train)

        self.btn_run = QPushButton('Run YOLO Detection')
        self.btn_run.clicked.connect(self.run_yolo_detection)
        main_layout.addWidget(self.btn_run)

        self.image_label = QLabel("Heatmap Output")
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setStyleSheet("border: 1px solid black;")
        self.image_label.setFixedSize(800, 800)
        main_layout.addWidget(self.image_label)

        self.setLayout(main_layout)
        self.setWindowTitle('Heatmap Generator with YOLO')
        self.show()

    def load_db_files(self):
        """Return the ``.db`` file names in ``self.local_path``, sorted by name.

        Sorting makes the default selection deterministic; ``os.listdir``
        returns entries in arbitrary order.
        """
        return sorted(
            file for file in os.listdir(self.local_path) if file.endswith(".db")
        )

    def change_selected_db(self):
        """Track the combo box selection; an empty selection becomes ``None``."""
        self.selected_db = self.dataset_selector.currentText() or None

    def generate_heatmap(self):
        """Render the selected sample range as a two-panel heatmap image.

        Reads the ``data`` table of the selected database, slices the columns
        ``[lower, higher)`` entered by the user, plots the normalized Bx and
        Bz magnitudes, saves the figure and shows it in the window. Problems
        are reported on stdout instead of being raised, because an exception
        escaping a Qt slot would terminate the application.
        """
        if not self.selected_db:
            print("❌ No database selected.")
            return

        try:
            lower_range = int(self.input_low.text())
            higher_range = int(self.input_high.text())
        except ValueError:
            print("❌ Error: lower and higher range must be integers.")
            return

        # The reference column must exist inside the selected slice.
        min_width = self.REFERENCE_COLUMN + 1
        if lower_range < 0 or higher_range - lower_range < min_width:
            print(
                "❌ Error: range must start at 0 or above and span at least "
                f"{min_width} samples."
            )
            return

        try:
            data = self._fetch_rows(os.path.join(self.local_path, self.selected_db))
            if not data:
                print("❌ No data found in the database.")
                return

            data_matrix = np.array(data, dtype=float).T  # Transpose, as in MATLAB

            if data_matrix.shape[0] < self.REAL_ROWS.stop:
                print(
                    f"❌ Error: expected at least {self.REAL_ROWS.stop} columns "
                    f"in table 'data', found {data_matrix.shape[0]}."
                )
                return

            # The slice end is exclusive, so ``higher == sample count`` is valid.
            if higher_range > data_matrix.shape[1]:
                print("❌ Selected range exceeds dataset size.")
                return

            subset_data = data_matrix[:, lower_range:higher_range]
            bx_normalized, bz_normalized = self._normalized_magnitudes(subset_data)

            filename = self.HEATMAP_FILENAME
            self._render_heatmap(bx_normalized, bz_normalized, filename)
            self.image_label.setPixmap(
                QPixmap(filename).scaled(800, 800, Qt.KeepAspectRatio)
            )
            self.heatmap_path = filename
            print("✅ Heatmap saved.")

        except Exception as e:
            # Top-level boundary of the slot: report and keep the GUI alive.
            print(f"❌ Error: {e}")

    @staticmethod
    def _fetch_rows(db_path):
        """Return every row of table ``data``; the connection is always closed."""
        # sqlite3.connect() would silently create a missing file.
        if not os.path.isfile(db_path):
            raise FileNotFoundError(f"database not found: {db_path}")
        conn = sqlite3.connect(db_path)
        try:
            return conn.execute("SELECT * FROM data").fetchall()
        finally:
            conn.close()

    def _normalized_magnitudes(self, subset_data):
        """Return ``(bx, bz)`` magnitudes relative to the reference column."""
        real_data = subset_data[self.REAL_ROWS, :]
        imaginary_data = subset_data[self.IMAG_ROWS, :]

        bx_magnitude = np.sqrt(
            real_data[self.BX_ROWS, :] ** 2 + imaginary_data[self.BX_ROWS, :] ** 2
        )
        bz_magnitude = np.sqrt(
            real_data[self.BZ_ROWS, :] ** 2 + imaginary_data[self.BZ_ROWS, :] ** 2
        )
        return (
            self._relative_to_reference(bx_magnitude),
            self._relative_to_reference(bz_magnitude),
        )

    def _relative_to_reference(self, magnitude):
        """Return ``(magnitude - ref) / ref`` per row, with ``ref`` 0 replaced by 1."""
        ref = magnitude[:, self.REFERENCE_COLUMN]
        ref = np.where(ref == 0, 1, ref)  # Avoid division by zero
        return (magnitude - ref[:, None]) / ref[:, None]

    @staticmethod
    def _render_heatmap(bx_normalized, bz_normalized, filename):
        """Save a figure with Bx on top and Bz below, then release the figure."""
        fig, axes = plt.subplots(2, 1, figsize=(8, 8))
        try:
            axes[0].imshow(bx_normalized, cmap='jet', aspect='auto')
            axes[1].imshow(bz_normalized, cmap='jet', aspect='auto')
            fig.savefig(filename)
        finally:
            # pyplot keeps every figure alive until it is closed explicitly.
            plt.close(fig)

    def train_yolo_model(self):
        """Train YOLO model in a separate thread to prevent UI freezing."""
        if YOLO is None:
            print("❌ YOLO library is not installed.")
            return
        if self._training:
            # Replacing the reference to a running QThread would destroy it.
            print("❌ Training is already running.")
            return

        self._training = True
        self.btn_train.setEnabled(False)

        self._train_thread = QThread()
        self._train_worker = TrainYOLOWorker()
        self._train_worker.moveToThread(self._train_thread)
        self._train_thread.started.connect(self._train_worker.run)
        self._train_worker.finished.connect(self._train_thread.quit)
        self._train_worker.finished.connect(self._train_worker.deleteLater)
        self._train_thread.finished.connect(self._on_training_finished)
        self._train_thread.finished.connect(self._train_thread.deleteLater)
        self._train_thread.start()

    def _on_training_finished(self):
        """Allow a new training run once the worker thread has stopped."""
        self._training = False
        self.btn_train.setEnabled(True)

    def run_yolo_detection(self):
        """Run YOLO detection safely."""
        if YOLO is None:
            print("❌ YOLO library is not installed.")
            return
        if not self.heatmap_path or not os.path.exists(self.heatmap_path):
            print("❌ No heatmap available for detection.")
            return

        try:
            # Load the weights once and reuse them on later runs.
            if self.yolo_model is None:
                self.yolo_model = YOLO("yolov8n.pt")
            results = self.yolo_model(self.heatmap_path)

            # Results.plot() returns a BGR array while imsave expects RGB,
            # so reverse the channel axis before writing the file.
            result_img = np.ascontiguousarray(results[0].plot()[..., ::-1])
            result_filename = self.DETECTION_FILENAME
            plt.imsave(result_filename, result_img)
            self.image_label.setPixmap(
                QPixmap(result_filename).scaled(800, 800, Qt.KeepAspectRatio)
            )
            print(f"✅ Detection results saved as {result_filename}")
        except Exception as e:
            # Same boundary as generate_heatmap: never let a slot raise.
            print(f"❌ Error: {e}")

class TrainYOLOWorker(QThread):
    """Worker object whose ``run`` trains a YOLOv8 model and then signals completion.

    ``finished`` is emitted whether training succeeds or fails, so whoever
    drives the worker can always tear its thread down.
    """

    # Declared here, so it takes the place of the inherited QThread.finished.
    finished = pyqtSignal()

    def run(self):
        """Train ``yolov8n.pt`` on ``dataset.yaml`` and report the outcome on stdout."""
        try:
            print("🔹 Training YOLO model...")
            YOLO("yolov8n.pt").train(data="dataset.yaml", epochs=10, imgsz=640)
            print("✅ YOLO model trained successfully!")
        except Exception as error:
            # Failures are reported, not raised: nothing outside would catch them.
            print(f"❌ Training failed: {error}")
        finally:
            self.finished.emit()


# Script entry point: build the Qt application, show the window and hand
# control to the event loop; its return code becomes the process exit status.
if __name__ == '__main__':
    app = QApplication(sys.argv)
    ex = HeatmapGUI()  # Keep a reference so the window is not garbage-collected
    exit_code = app.exec_()
    sys.exit(exit_code)


# :