# In [None]:  (notebook cell marker left over from export; not part of the program)
# -*- coding: utf-8 -*-
import sys
import time
import numpy as np
import pickle
from PyQt5.QtCore import Qt, pyqtSignal
from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QDialog, QFileDialog,
    QLabel, QCheckBox, QDoubleSpinBox, QSpinBox,
    QPushButton, QVBoxLayout, QHBoxLayout, QFormLayout,
    QProgressBar, QMessageBox, QWidget
)
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure

# my functions
from Functions import DataGeneration

#######################
# Matplotlib Widget
#######################
class MatplotlibWidget(QWidget):
    """QWidget that hosts a single Matplotlib canvas in a vertical layout."""

    def __init__(self, parent=None):
        super().__init__(parent)

        # The Figure is owned by the canvas; it is reachable via get_figure().
        figure = Figure(figsize=(6, 4))
        self.canvas = FigureCanvas(figure)

        box = QVBoxLayout()
        box.addWidget(self.canvas)
        self.setLayout(box)

    def get_figure(self):
        """Return the Matplotlib Figure drawn by this widget's canvas."""
        canvas = self.canvas
        return canvas.figure

#######################
# 1) Window1: TopDownDialog
#######################
class TopDownDialog(QDialog):
    """Window1: collects the top-down flag together with the S and R values.

    Pressing "Next" emits ``submitted`` with a dict holding the keys
    'top_down', 'S' and 'R', then accepts (closes) the dialog.
    """

    submitted = pyqtSignal(dict)

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Window1 - Top-Down Parameters")
        self.resize(400, 200)
        self.init_ui()

    @staticmethod
    def _make_spin(initial):
        """Build a QDoubleSpinBox with 3 decimals, step 0.1 and *initial* value."""
        box = QDoubleSpinBox()
        # Same call order as the hand-written setup: value, decimals, step.
        box.setValue(initial)
        box.setDecimals(3)
        box.setSingleStep(0.1)
        return box

    def init_ui(self):
        """Create the form (checkbox + two spin boxes) and the Next button."""
        self.check_top_down = QCheckBox("Use top-down")
        self.spin_S = self._make_spin(0.3)
        self.spin_R = self._make_spin(1.0)

        form = QFormLayout()
        for caption, widget in (
            ("Top-down:", self.check_top_down),
            ("S value:", self.spin_S),
            ("R value:", self.spin_R),
        ):
            form.addRow(caption, widget)

        next_button = QPushButton("Next")
        next_button.clicked.connect(self.on_next_clicked)

        root = QVBoxLayout()
        root.addLayout(form)
        root.addWidget(next_button)
        self.setLayout(root)

    def on_next_clicked(self):
        """Emit the current widget values via ``submitted`` and accept the dialog."""
        payload = dict(
            top_down=self.check_top_down.isChecked(),
            S=self.spin_S.value(),
            R=self.spin_R.value(),
        )
        self.submitted.emit(payload)
        self.accept()

#######################
# 2) Window2: ParamDialog
#######################
class ParamDialog(QDialog):
    """Window2: collects the network parameters and the run settings.

    Pressing "Go to Execution" emits ``submitted`` with
    ``{'params': {...}, 'others': {...}}`` and accepts (closes) the dialog.
    """

    submitted = pyqtSignal(dict)

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Window2 - Other Parameters")
        self.resize(600, 300)
        self.init_ui()

    @staticmethod
    def _int_spin(upper, initial):
        """Build a QSpinBox with maximum *upper*, pre-set to *initial*."""
        box = QSpinBox()
        # The maximum must be raised before the value is set, otherwise the
        # value would be clamped to the widget's default range.
        box.setMaximum(upper)
        box.setValue(initial)
        return box

    @staticmethod
    def _float_spin(decimals, initial):
        """Build a QDoubleSpinBox with *decimals* digits, pre-set to *initial*."""
        box = QDoubleSpinBox()
        # Precision is configured first so the initial value is not rounded
        # to the widget's default number of decimals.
        box.setDecimals(decimals)
        box.setValue(initial)
        return box

    def init_ui(self):
        """Create both parameter columns and the submit button."""
        columns = QHBoxLayout()

        # Default parameters
        self.n_E_val = 60
        self.n_I_val = 50
        self.n_NSE_val = 280
        self.num_trial = 5
        self.gamma_ei_val = 0.0
        self.gamma_ie_val = 0.0
        self.gamma_ee_val = 1.25
        self.gamma_ii_val = 0.0
        self.bg_len_val = 200
        self.duration_val = 2000
        self.sensory_input_val = 0.00156
        self.threshold_val = 60

        # Column 1: population sizes and gamma values
        left = QFormLayout()

        self.spin_nE = self._int_spin(9999, self.n_E_val)
        left.addRow("n_E:", self.spin_nE)

        self.spin_nI = self._int_spin(999, self.n_I_val)
        left.addRow("n_I:", self.spin_nI)

        self.spin_nNSE = self._int_spin(999, self.n_NSE_val)
        left.addRow("n_NSE:", self.spin_nNSE)

        self.spin_gEE = self._float_spin(3, self.gamma_ee_val)
        left.addRow("gamma_ee:", self.spin_gEE)

        self.spin_gEI = self._float_spin(3, self.gamma_ei_val)
        left.addRow("gamma_ei:", self.spin_gEI)

        self.spin_gIE = self._float_spin(3, self.gamma_ie_val)
        left.addRow("gamma_ie:", self.spin_gIE)

        self.spin_gII = self._float_spin(3, self.gamma_ii_val)
        left.addRow("gamma_ii:", self.spin_gII)

        # Column 2: run settings
        right = QFormLayout()

        self.spin_num_trial = self._int_spin(999, self.num_trial)
        right.addRow("num_trial:", self.spin_num_trial)

        self.spin_threshold = self._int_spin(150, self.threshold_val)
        right.addRow("threshold:", self.spin_threshold)

        self.spin_trial_len = self._int_spin(9999, self.duration_val)
        right.addRow("trial_len (ms):", self.spin_trial_len)

        self.spin_bg_len = self._int_spin(999, self.bg_len_val)
        right.addRow("BG_len (ms):", self.spin_bg_len)

        self.spin_input_amp = self._float_spin(6, self.sensory_input_val)
        right.addRow("input_amp:", self.spin_input_amp)

        columns.addLayout(left)
        columns.addLayout(right)

        root = QVBoxLayout()
        root.addLayout(columns)

        submit_button = QPushButton("Go to Execution")
        submit_button.clicked.connect(self.on_submit_clicked)
        root.addWidget(submit_button)
        self.setLayout(root)

    def on_submit_clicked(self):
        """Read every spin box, emit the grouped values, and accept the dialog."""
        network_fields = (
            ('n_E', self.spin_nE),
            ('n_I', self.spin_nI),
            ('n_NSE', self.spin_nNSE),
            ('gamma_ee', self.spin_gEE),
            ('gamma_ei', self.spin_gEI),
            ('gamma_ie', self.spin_gIE),
            ('gamma_ii', self.spin_gII),
        )
        run_fields = (
            ('num_trial', self.spin_num_trial),
            ('threshold', self.spin_threshold),
            ('trial_len', self.spin_trial_len),
            ('BG_len', self.spin_bg_len),
            ('input_amp', self.spin_input_amp),
        )
        payload = {
            'params': {key: box.value() for key, box in network_fields},
            'others': {key: box.value() for key, box in run_fields},
        }
        self.submitted.emit(payload)
        self.accept()

#########################################
# 3) Window3: ExecDialog
#########################################
class ExecDialog(QDialog):
    """
    Execution dialog (Window3).

    1) Runs ``num_trial`` single-trial calls to ``DataGeneration`` for each
       entry of ``coherence_list``.
    2) Plots performance (no error bar) and reaction time (with error bar)
       against coherence.
    3) After finishing, asks the user whether to save the per-trial results
       (``result_dict``) to a .pkl file.
    4) Shows a large title label with gamma_ee, gamma_ei, gamma_ie, gamma_ii
       (plus S and R when top_down is enabled).

    The simulation runs on the GUI thread; the event loop is pumped with
    ``QApplication.processEvents()`` between trials so the progress bar and
    status label stay responsive.
    """
    def __init__(self, parent=None):
        super().__init__(parent)
        self.setWindowTitle("Window3 - Execution")
        self.resize(800, 600)

        # Defaults used until set_parameters() is called.
        self.params = None
        self.top_down = False
        self.S = 0.3
        self.R = 1.0
        self.num_trial = 10
        self.threshold = 60
        self.trial_len = 2000
        self.BG_len = 300
        self.input_amp = 0.00156
        self.coherence_list = [0, 3.2, 6.4, 12.8, 25.6, 51.2]

        # Maps "coh<value>" -> {trial index -> raw trial data}.
        self.result_dict = {}

        # True while on_run_clicked() is executing; guards against re-entry
        # caused by events delivered through processEvents().
        self._is_running = False

        self.init_ui()

    def init_ui(self):
        """Build the title label, status row, Run button and plot area."""
        main_layout = QVBoxLayout()

        self.label_title = QLabel("No Title Yet")
        self.label_title.setStyleSheet("font-size: 16pt; font-weight: bold;")
        main_layout.addWidget(self.label_title)

        h_layout = QHBoxLayout()
        self.label_info = QLabel("Ready.")
        self.progress_bar = QProgressBar()
        self.progress_bar.setValue(0)
        h_layout.addWidget(self.label_info)
        h_layout.addWidget(self.progress_bar)
        main_layout.addLayout(h_layout)

        self.btn_run = QPushButton("Run")
        self.btn_run.clicked.connect(self.on_run_clicked)
        main_layout.addWidget(self.btn_run)

        self.plot_widget = MatplotlibWidget(self)
        main_layout.addWidget(self.plot_widget)

        self.setLayout(main_layout)

    def set_parameters(self, merged_dict, topdown_dict):
        """
        Load the parameters collected by the previous dialogs.

        Args:
            merged_dict: dict with a 'params' dict (forwarded to
                ``DataGeneration``) and an 'others' dict providing
                'num_trial', 'threshold', 'trial_len', 'BG_len' and
                'input_amp'.
            topdown_dict: dict with the keys 'top_down', 'S' and 'R'.

        Raises:
            KeyError: if one of the keys listed above is missing.
        """
        self.top_down = topdown_dict['top_down']
        self.S = topdown_dict['S']
        self.R = topdown_dict['R']

        # Work on a copy so the caller's dict is not modified when the
        # top-down entries are merged in.
        self.params = dict(merged_dict['params'])
        self.params['top_down'] = self.top_down
        self.params['S'] = self.S
        self.params['R'] = self.R

        others = merged_dict['others']
        self.num_trial = others['num_trial']
        self.threshold = others['threshold']
        self.trial_len = others['trial_len']
        self.BG_len = others['BG_len']
        self.input_amp = others['input_amp']

        gamma_ee = self.params.get('gamma_ee', 0.0)
        gamma_ei = self.params.get('gamma_ei', 0.0)
        gamma_ie = self.params.get('gamma_ie', 0.0)
        gamma_ii = self.params.get('gamma_ii', 0.0)

        title_str = f"gamma_ee={gamma_ee}, gamma_ei={gamma_ei}, gamma_ie={gamma_ie}, gamma_ii={gamma_ii}"
        if self.top_down:
            title_str += f", S={self.S}, R={self.R}"

        self.label_title.setText(title_str)

        self.label_info.setText("Parameters loaded. Click 'Run' to start.")

    def on_run_clicked(self):
        """
        Run every coherence level, then plot the summary and offer to save.

        The Run button is disabled for the duration of the run: the trial
        loop pumps the event loop, so without this a second click would
        start a nested run on top of the current one.
        """
        if self._is_running:
            return
        self._is_running = True
        self.btn_run.setEnabled(False)
        try:
            summary = self._run_all_coherences()
            if summary is None:
                self.label_info.setText("Run aborted: dialog was closed.")
                return

            performance_array, rt_mean_array, rt_std_array = summary
            self.label_info.setText("Finished. Plotting results...")
            self.plot_and_show(performance_array, rt_mean_array, rt_std_array)
            self.label_info.setText("Plot done. Asking for save...")
            self.ask_to_save_file()
        finally:
            self._is_running = False
            self.btn_run.setEnabled(True)

    def _show_status(self, text):
        """Update the status label and let Qt repaint/process pending events."""
        self.label_info.setText(text)
        self.label_info.repaint()
        QApplication.processEvents()

    def _run_all_coherences(self):
        """
        Run ``num_trial`` trials for each coherence level.

        Fills ``self.result_dict`` and returns
        ``(performance_array, rt_mean_array, rt_std_array)``, one entry per
        coherence level. Returns ``None`` when the dialog was visible at the
        start and got closed while the run was in progress; in that case
        ``self.result_dict`` only holds the coherence levels that completed.
        """
        # Only treat "hidden" as a cancel request if we were visible to
        # begin with, so programmatic runs on a never-shown dialog still work.
        shown_at_start = self.isVisible()

        performance_array = []
        rt_mean_array = []
        rt_std_array = []
        self.result_dict = {}

        n_coh = len(self.coherence_list)
        for i, coh in enumerate(self.coherence_list):
            # A (0, 0) range would switch QProgressBar to its busy-indicator
            # mode, so keep the maximum at least 1 even with zero trials.
            self.progress_bar.setRange(0, max(1, self.num_trial))
            self.progress_bar.setValue(0)
            self._show_status(f"Running Coherence {i+1}/{n_coh} = {coh}")

            ER_count = 0
            EL_count = 0
            ER_RT_list = []
            coherence_result_dict = {}

            for t in range(self.num_trial):
                self._show_status(f"coh={coh} | trial {t+1}/{self.num_trial}")
                if shown_at_start and not self.isVisible():
                    return None

                # DataGeneration returns six values; only the first (a
                # mapping of trial index -> trial data) is used here.
                (result_list,
                 _total_EL_RT,
                 _total_ER_RT,
                 _no_decision,
                 _EL_firing,
                 _ER_firing) = DataGeneration(
                    params=self.params,
                    num_trial=1,  # run a single trial per call
                    coherence=coh,
                    threshold=self.threshold,
                    trial_len=self.trial_len,
                    BG_len=self.BG_len,
                    input_amp=self.input_amp
                )

                trial_data = result_list.get(0, None)
                coherence_result_dict[t] = trial_data

                if trial_data is not None:
                    # Layout read here: [0]=ER flag, [1]=EL flag, [2]=ER RT.
                    # NOTE(review): presumably flags are 0/1 decision
                    # indicators -- confirm against DataGeneration.
                    ER_flag = trial_data[0]
                    EL_flag = trial_data[1]
                    ER_RT = trial_data[2]

                    ER_count += ER_flag
                    EL_count += EL_flag
                    if ER_flag == 1 and ER_RT is not None:
                        ER_RT_list.append(ER_RT)

                self.progress_bar.setValue(t+1)
                QApplication.processEvents()
                time.sleep(0.01)

            self.result_dict[f"coh{coh}"] = coherence_result_dict

            perf, mean_rt, std_rt = self._summarize(ER_count, EL_count, ER_RT_list)
            performance_array.append(perf)
            rt_mean_array.append(mean_rt)
            rt_std_array.append(std_rt)

        return performance_array, rt_mean_array, rt_std_array

    @staticmethod
    def _summarize(ER_count, EL_count, ER_RT_list):
        """
        Reduce one coherence level to ``(performance, mean_rt, std_rt)``.

        Performance is ER_count / (ER_count + EL_count), or 0.0 when no
        decision was counted. RT statistics are taken over ``ER_RT_list``
        and are both 0.0 when that list is empty.
        """
        total_decisions = ER_count + EL_count
        perf = ER_count / total_decisions if total_decisions else 0.0

        if ER_RT_list:
            return perf, float(np.mean(ER_RT_list)), float(np.std(ER_RT_list))
        return perf, 0.0, 0.0

    def plot_and_show(self, performance_array, rt_mean_array, rt_std_array):
        """
        Draw performance (left axis) and reaction time with error bars
        (right axis) against ``self.coherence_list``.

        Args:
            performance_array: one performance value per coherence level.
            rt_mean_array: one mean reaction time per coherence level.
            rt_std_array: one reaction-time std per coherence level, used
                as the error-bar half-height.
        """
        fig = self.plot_widget.get_figure()
        fig.clear()
        ax = fig.add_subplot(111)
        ax2 = ax.twinx()

        # Performance
        ax.plot(self.coherence_list, performance_array, 'bo-', label="Performance")
        ax.set_ylabel("Performance (fraction E_R)")
        ax.set_xlabel("Coherence (%)")
        ax.set_ylim(0, 1.05)

        # RT
        ax2.errorbar(
            self.coherence_list,
            rt_mean_array,
            yerr=rt_std_array,
            fmt='rs--',
            capsize=5,
            label="Reaction Time"
        )
        ax2.set_ylabel("Reaction Time (ms)")

        # Size the RT axis from mean +/- std so the error bars are not
        # clipped (limits based on the means alone cut them off).
        if rt_mean_array:
            min_rt = min(m - s for m, s in zip(rt_mean_array, rt_std_array))
            max_rt = max(m + s for m, s in zip(rt_mean_array, rt_std_array))
        else:
            min_rt, max_rt = 0, 100
        ax2.set_ylim(min_rt-10, max_rt+10)

        ax.grid(True)
        ax.set_title("Performance & RT vs Coherence")
        ax.legend(loc='upper left')
        ax2.legend(loc='upper right')
        self.plot_widget.canvas.draw()

    def ask_to_save_file(self):
        """
        Ask whether to save ``self.result_dict`` and, if so, pickle it to
        the path chosen in a save-file dialog.

        Write failures are reported in a warning box instead of propagating
        out of the Qt slot.
        """
        reply = QMessageBox.question(
            self,
            "Save Data?",
            "Do you want to save the result data to a .pkl file?",
            QMessageBox.Yes | QMessageBox.No,
            QMessageBox.No
        )
        if reply != QMessageBox.Yes:
            self.label_info.setText("User chose not to save.")
            return

        file_path, _ = QFileDialog.getSaveFileName(
            self,
            "Save As",
            "",
            "Pickle Files (*.pkl);;All Files (*)"
        )
        if not file_path:
            self.label_info.setText("User canceled save dialog.")
            return

        try:
            with open(file_path, 'wb') as f:
                pickle.dump(self.result_dict, f)
        except (OSError, pickle.PicklingError) as exc:
            QMessageBox.warning(
                self,
                "Save Failed",
                f"Could not save data to: {file_path}\n{exc}"
            )
            self.label_info.setText("Save failed.")
        else:
            self.label_info.setText(f"Data saved to: {file_path}")

#######################
# Main Window
#######################
class MainWindow(QMainWindow):
    """
    Demo main window with one button that walks the user through
    Window1 (TopDownDialog) -> Window2 (ParamDialog) -> Window3 (ExecDialog).

    Each dialog is closed before the next one opens. Cancelling Window1 or
    Window2 stops the sequence.
    """
    def __init__(self):
        super().__init__()
        self.setWindowTitle("MainWindow (for demonstration)")
        self.resize(300, 100)

        btn_start = QPushButton("Start Setting Parameters")
        btn_start.clicked.connect(self.start_sequence)

        central_widget = QWidget()
        vlayout = QVBoxLayout()
        vlayout.addWidget(btn_start)
        central_widget.setLayout(vlayout)
        self.setCentralWidget(central_widget)

        # Last values submitted by Window1 and Window2 respectively.
        self.topdown_params = {}
        self.merged_params = {}

    def _run_dialog(self, dialog_cls):
        """
        Show ``dialog_cls(self)`` modally and return the dict it submitted.

        Returns ``None`` when the dialog was closed without being accepted.
        The dialog is scheduled for deletion afterwards so repeated runs do
        not accumulate hidden child dialogs on the main window.
        """
        captured = []
        dlg = dialog_cls(self)
        # Only record the payload here; the next window is opened after
        # exec_() has returned, i.e. once this dialog is actually closed.
        dlg.submitted.connect(lambda payload: captured.append(payload))
        accepted = dlg.exec_() == QDialog.Accepted
        dlg.deleteLater()
        if accepted and captured:
            return captured[-1]
        return None

    def start_sequence(self):
        """Start the parameter sequence with Window1."""
        # Window1
        topdown = self._run_dialog(TopDownDialog)
        if topdown is not None:
            self.on_topdown_submitted(topdown)

    def on_topdown_submitted(self, data_dict):
        """Store the Window1 values, then show Window2."""
        self.topdown_params = data_dict
        # Window2
        merged = self._run_dialog(ParamDialog)
        if merged is not None:
            self.on_param_submitted(merged)

    def on_param_submitted(self, data_dict):
        """Store the Window2 values, then show Window3 with all parameters."""
        self.merged_params = data_dict
        # Window3
        dlg3 = ExecDialog(self)
        dlg3.set_parameters(self.merged_params, self.topdown_params)
        dlg3.exec_()
        dlg3.deleteLater()


if __name__ == "__main__":
    # Script entry point: create the Qt application, show the main window,
    # and hand the event loop's return code back to the shell.
    app = QApplication(sys.argv)
    win = MainWindow()
    win.show()
    exit_code = app.exec_()
    sys.exit(exit_code)



# KeyboardInterrupt:  (notebook cell output left over from export; not part of the program)