# In [None]:
import socket
import threading
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
from tkinter import Tk, Button, Label, StringVar, OptionMenu, messagebox, filedialog
import pickle

# Define a neural network model compatible with ECG data
class ECGNN(nn.Module):
    """Small fully-connected binary classifier for fixed-length ECG records.

    Architecture: ``Linear(input_size, hidden_size) -> ReLU ->
    Linear(hidden_size, 1) -> Sigmoid``. The output is a probability in
    (0, 1), which is what ``nn.BCELoss`` expects.

    The layer attribute names (``fc1``, ``fc2``) determine the state-dict keys
    exchanged with clients, so they must not change.

    Args:
        input_size: Number of ECG data points per record (feature columns).
            Defaults to 140, the record length this script was written for.
        hidden_size: Width of the hidden layer. Defaults to 128.
    """

    def __init__(self, input_size=140, hidden_size=128):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)  # single unit -> binary output

    def forward(self, x):
        """Map a ``(..., input_size)`` tensor to ``(..., 1)`` probabilities."""
        hidden = torch.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(hidden))  # sigmoid pairs with BCELoss

# ---- Shared server state ---------------------------------------------------
# The one model the server trains on its own CSV data and averages client
# updates into; it is also what clients download.
global_model = ECGNN()

# IP addresses of clients that have connected so far (appended by the
# connection handler, shown in the GUI dropdown).
clients = []

# DataLoader built from the CSV picked in the GUI; None until one is uploaded.
server_loader = None

# Tkinter handles populated by setup_gui(); None until the window is built.
root = client_var = client_menu = None

# Load data from CSV file
def load_data_from_csv(file_path, batch_size=16, shuffle=True):
    """Load a labelled CSV into a ``DataLoader`` of float32 tensors.

    Every column except the last is a feature; the last column is the label.
    The file is read with ``pandas.read_csv`` defaults, so the first row is
    treated as a header.

    Args:
        file_path: Path or file-like object accepted by ``pandas.read_csv``.
        batch_size: Samples per batch. Defaults to 16.
        shuffle: Reshuffle rows every epoch. Defaults to True.

    Returns:
        A ``DataLoader`` over ``TensorDataset(X, y)`` where ``X`` has shape
        ``(rows, columns - 1)`` and ``y`` has shape ``(rows, 1)``.

    Raises:
        ValueError: If the file has no data rows, no feature column, or
            values that cannot be converted to float.
    """
    df = pd.read_csv(file_path)
    # Fail with a clear message instead of an obscure sampler/shape error later.
    if df.empty or df.shape[1] < 2:
        raise ValueError(
            "CSV must contain at least one data row, one feature column and a label column"
        )

    features = df.iloc[:, :-1].values
    labels = df.iloc[:, -1].values

    X_tensor = torch.tensor(features, dtype=torch.float32)
    # (rows,) -> (rows, 1) so targets match the model's (batch, 1) output.
    y_tensor = torch.tensor(labels, dtype=torch.float32).unsqueeze(1)

    dataset = TensorDataset(X_tensor, y_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# Train the global model using server-side dataset
def train_global_model(epochs=5, lr=0.01, model=None, loader=None):
    """Fine-tune the global model on the server-side dataset.

    Runs plain SGD with binary cross-entropy for ``epochs`` passes over the
    loader and prints the mean loss of each epoch.

    Args:
        epochs: Number of passes over the data. Defaults to 5.
        lr: SGD learning rate. Defaults to 0.01.
        model: Model to train; defaults to the module-level ``global_model``.
        loader: Iterable of ``(data, target)`` batches; defaults to the
            module-level ``server_loader`` set by the CSV upload.

    Returns:
        List of mean per-epoch losses. It is empty when there was nothing to
        train on. Existing callers may ignore it.
    """
    # NOTE(review): this is reached from both the Tk thread (CSV upload) and
    # the server thread (aggregate) with no lock; concurrent calls would race
    # on the same model.
    if model is None:
        model = global_model
    if loader is None:
        loader = server_loader
    # Aggregation can happen before any CSV has been uploaded; iterating None
    # would raise TypeError inside the client handler.
    if loader is None:
        print("No server dataset loaded; skipping global model training.")
        return []

    criterion = nn.BCELoss()  # model ends in a sigmoid, so BCE on probabilities
    optimizer = optim.SGD(model.parameters(), lr=lr)
    epoch_losses = []

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for data, target in loader:
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            # No batches at all: there is no loss to report, stop early.
            print("Server dataset is empty; skipping global model training.")
            return epoch_losses

        mean_loss = total_loss / num_batches
        epoch_losses.append(mean_loss)
        print(f"Global Model Training - Epoch {epoch+1}, Loss: {mean_loss}")

    print("Global model training completed.")
    return epoch_losses

# Aggregate client updates
def aggregate(global_model, client_models, retrain=True):
    """Replace the model's weights with the unweighted mean of client weights.

    Each floating-point entry of ``global_model.state_dict()`` becomes the
    element-wise mean of the matching entries in ``client_models``. The
    model's previous weights are not part of the average. Non-floating
    entries (integer buffers such as BatchNorm's ``num_batches_tracked``)
    cannot be averaged with ``mean`` and are copied from the first client.

    After loading, ``train_global_model()`` fine-tunes the module-level
    ``global_model`` on the server dataset. That is the same object as the
    ``global_model`` argument only when the caller passes it in, as the
    connection handler does.

    Args:
        global_model: Model whose state is overwritten in place.
        client_models: Sequence of state dicts with the same keys and shapes
            as ``global_model.state_dict()``. If empty, nothing happens.
        retrain: When False, skip the server-side fine-tuning step.
            Defaults to True (the original behaviour).
    """
    if not client_models:
        return  # nothing to average; leave the model untouched

    averaged = {}
    for key, current in global_model.state_dict().items():
        values = [state[key] for state in client_models]
        if current.is_floating_point():
            averaged[key] = torch.stack(values).mean(dim=0)
        else:
            averaged[key] = values[0].clone()
    global_model.load_state_dict(averaged)

    if retrain:
        train_global_model()

# Function to handle CSV file upload and training
def upload_and_train():
    """Ask for a CSV file, make it the server dataset and retrain the model.

    Bound to the "Upload CSV & Train Model" button. If loading or training
    fails, the previous server dataset is restored and the error is shown in
    a dialog. Without this, the exception would only reach Tk's callback
    traceback on stderr.
    """
    # NOTE(review): training runs on the Tk thread, so the window does not
    # respond until it finishes.
    global server_loader
    file_path = filedialog.askopenfilename(filetypes=[("CSV files", "*.csv")])
    if not file_path:
        return  # dialog cancelled

    previous_loader = server_loader
    try:
        server_loader = load_data_from_csv(file_path)
        # train_global_model() reads the module-level server_loader, so it is
        # assigned before training starts.
        train_global_model()
    except (OSError, ValueError, TypeError, RuntimeError) as exc:
        # OSError: unreadable file; ValueError: bad/empty CSV or non-numeric
        # data; TypeError: tensor conversion; RuntimeError: feature count does
        # not match the model. Don't keep a loader that later client updates
        # would retrain on and fail again.
        server_loader = previous_loader
        messagebox.showerror("Training Failed", f"Could not train on {file_path}:\n{exc}")
        return

    messagebox.showinfo("Training Complete", "Global model has been trained with the uploaded data.")

# Notify via GUI when a client connects
def notify_client_connected(client_ip):
    """Schedule a refresh of the GUI client dropdown after a new connection.

    This is called from the server thread. Tkinter widgets should only be
    touched from the thread running ``mainloop()``, so the refresh is queued
    with ``root.after`` instead of being performed here.

    Args:
        client_ip: IP address of the client that connected (used for logging).
    """
    # The server thread starts before the GUI, and setup_gui() assigns root
    # before it creates client_menu. Skip until the widgets exist; otherwise
    # update_client_list would fail on a None menu inside the server thread.
    if root is None or client_menu is None:
        return

    from tkinter import TclError  # local: the file imports only selected tkinter names

    try:
        root.after(0, update_client_list)
    except (RuntimeError, TclError) as exc:
        # RuntimeError: Tk main loop not running yet; TclError: window destroyed.
        print(f"Could not refresh client list after {client_ip} connected: {exc}")

# Update client list in GUI dropdown
def update_client_list():
    """Rebuild the client dropdown from ``clients`` and select the first entry.

    Shows the placeholder text when no client has connected yet. It edits Tk
    widgets, so it should run on the Tk main-loop thread.
    """
    if clients:
        client_var.set(clients[0])
    else:
        client_var.set("No clients connected")

    dropdown = client_menu['menu']
    dropdown.delete(0, 'end')

    def _selector(ip):
        # Bind ip now; a bare lambda inside the loop would see only the last value.
        return lambda: client_var.set(ip)

    for ip in clients:
        dropdown.add_command(label=ip, command=_selector(ip))

# Set up the GUI
def setup_gui():
    """Build the Tkinter window and run its event loop (blocks until closed).

    Creates the title label, the connected-clients dropdown and the
    "Upload CSV & Train Model" button. The window, dropdown variable and
    dropdown widget are stored in module globals so other functions can
    update them.
    """
    global root, client_var, client_menu
    root = Tk()
    root.title("Federated Learning with ECG Data")

    title = Label(root, text="Federated Learning Interface", font=("Arial", 16))
    title.pack(pady=20)

    placeholder = "No clients connected"
    # Snapshot: the server thread may append to clients concurrently.
    initial_clients = list(clients)

    client_var = StringVar(root)
    client_var.set(initial_clients[0] if initial_clients else placeholder)

    # OptionMenu takes its choices as separate positional arguments and needs
    # at least one. Passing the list object itself creates a single bogus entry.
    client_menu = OptionMenu(root, client_var, *(initial_clients or [placeholder]))
    client_menu.pack(pady=20)

    upload_button = Button(root, text="Upload CSV & Train Model", command=upload_and_train)
    upload_button.pack(pady=20)

    root.mainloop()

# Handle incoming client connections
def handle_client_connection(client_socket, client_address):
    """Serve one client connection until it closes or sends a model update.

    Wire protocol as implemented here:

    * ``REQUEST_GLOBAL_MODEL``: the server replies with a pickled
      ``state_dict`` of the global model (no length prefix) and keeps the
      connection open for further commands.
    * ``SEND_MODEL_UPDATE`` followed by a pickled state dict: the server
      reads until the client closes or shuts down its write side, aggregates
      the update into the global model, then closes the connection.

    Errors are logged and the socket is always closed. Nothing propagates,
    because this runs inside the accept loop and an escaping exception would
    stop the server.

    Args:
        client_socket: Connected socket returned by ``accept()``.
        client_address: ``(host, port)`` tuple returned by ``accept()``.
    """
    # NOTE(review): commands are matched against a single recv() chunk. TCP
    # may split or merge writes, so a robust protocol needs length framing.
    update_header = b"SEND_MODEL_UPDATE"
    client_ip = client_address[0]
    try:
        if client_ip not in clients:
            clients.append(client_ip)
            notify_client_connected(client_ip)
        print(f"Client {client_ip} connected.")

        while True:
            data = client_socket.recv(4096)
            if not data:
                break

            if data == b"REQUEST_GLOBAL_MODEL":
                # Send the global model to the client
                client_socket.sendall(pickle.dumps(global_model.state_dict()))

            elif data.startswith(update_header):
                # Payload bytes can arrive in the same chunk as the header;
                # keep them instead of dropping them and corrupting the pickle.
                chunks = [data[len(update_header):]]
                while True:
                    packet = client_socket.recv(4096)
                    if not packet:
                        break
                    chunks.append(packet)

                # SECURITY: pickle.loads on bytes from the network can execute
                # arbitrary code. Only accept connections from trusted clients,
                # or move to a data-only format (e.g. torch.load with
                # weights_only=True).
                client_model_state = pickle.loads(b"".join(chunks))
                aggregate(global_model, [client_model_state])
                break  # the update ends the conversation

    except socket.error as e:
        print(f"Error with client {client_ip}: {e}")
    except Exception as e:  # per-connection boundary: log and keep the server alive
        print(f"Failed to process data from client {client_ip}: {e}")

    finally:
        client_socket.close()
        print(f"Client {client_ip} disconnected.")

# Run the server to accept connections
def run_server(host='0.0.0.0', port=5002, backlog=5):
    """Accept client connections forever, handling them one at a time.

    Connections are processed sequentially on the calling thread. Updates
    arriving over the network are therefore aggregated one after another,
    although the GUI thread can still train concurrently.

    Args:
        host: Interface to bind. ``'0.0.0.0'`` listens on all IPv4 interfaces.
        port: TCP port to listen on. Defaults to 5002.
        backlog: Pending-connection queue length passed to ``listen()``.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
        # Allow an immediate restart while old connections are in TIME_WAIT;
        # without this, bind() fails with "Address already in use".
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_socket.bind((host, port))
        server_socket.listen(backlog)
        print("Server is listening for connections...")

        while True:
            client_socket, client_address = server_socket.accept()
            handle_client_connection(client_socket, client_address)

# Run the server and GUI concurrently
def main():
    """Start the socket server on a daemon thread, then run the GUI.

    The GUI call blocks in Tk's main loop. Because the server thread is a
    daemon, closing the window ends the whole process.
    """
    threading.Thread(target=run_server, daemon=True).start()
    setup_gui()


# Entry point when executed as a script (not when imported).
if __name__ == "__main__":
    main()
