In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import socket
from torch.utils.data import DataLoader, TensorDataset
from tkinter import Tk, Button, Label, messagebox
import pandas as pd
import pickle
import threading

# Define the client's model
class ECGNN(nn.Module):
    """Two-layer fully connected network for binary ECG classification.

    A 140-value input vector is mapped through a 128-unit ReLU hidden layer
    to a single sigmoid-activated output.
    """

    def __init__(self):
        super().__init__()
        # One input feature per ECG data point in a sample (140 in total).
        self.fc1 = nn.Linear(in_features=140, out_features=128)
        # Single output unit: normal vs. abnormal.
        self.fc2 = nn.Linear(in_features=128, out_features=1)

    def forward(self, x):
        """Return the sigmoid-activated output of the network for ``x``."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden).sigmoid()

# Local training function
def train_local(model, data_loader, epochs=1, lr=0.01):
    """Train ``model`` in place on ``data_loader`` using plain SGD.

    The loss is binary cross-entropy on probabilities (``nn.BCELoss``), so
    the model must already apply a sigmoid to its output, as
    ``ECGNN.forward`` does. Using ``nn.BCEWithLogitsLoss`` here would apply
    a second sigmoid on top of the model's own and distort the gradients.

    Args:
        model: Module whose forward pass returns probabilities in [0, 1].
        data_loader: Iterable yielding ``(data, target)`` batches; ``target``
            must have the same shape as the model output.
        epochs: Number of full passes over ``data_loader``.
        lr: SGD learning rate.
    """
    # BCELoss expects probabilities; the model's forward already ends in a sigmoid.
    criterion = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    model.train()
    for _ in range(epochs):
        for data, target in data_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

# Function to download the global model from the server
def get_global_model(client_socket):
    """Request the global model from the server and return its state dict.

    Sends ``REQUEST_GLOBAL_MODEL`` and then reads the pickled reply. The
    stream carries no length prefix, so completion is detected from the
    pickle itself: whenever a read is short or ends with the pickle STOP
    opcode (``b'.'``), the buffer is tentatively unpickled, and a
    "truncated" failure simply means more data is still on the way. A short
    ``recv`` alone is not a reliable end-of-message signal on TCP.

    Args:
        client_socket: Connected socket-like object providing ``sendall``
            and ``recv``.

    Returns:
        The unpickled object sent by the server, or ``None`` if the
        download failed (an error dialog is shown in that case).
    """
    try:
        client_socket.sendall(b"REQUEST_GLOBAL_MODEL")
        # bytearray extends in place, avoiding quadratic bytes concatenation.
        buffer = bytearray()
        while True:
            part = client_socket.recv(4096)
            if not part:
                # Peer closed the connection: whatever arrived is all there is.
                break
            buffer += part
            if len(part) < 4096 or part.endswith(b'.'):
                # SECURITY: unpickling data received over the network can
                # execute arbitrary code; only connect to a trusted server.
                try:
                    return pickle.loads(buffer)
                except (pickle.UnpicklingError, EOFError):
                    # Incomplete pickle so far; keep reading.
                    continue
        # SECURITY: same caveat as above applies to this final attempt.
        return pickle.loads(buffer)
    except Exception as e:
        messagebox.showerror("Error", f"Failed to download global model: {e}")
        return None

# Function to send local updates to the server
def send_model_update(client_socket, model):
    """Send the model's pickled ``state_dict`` to the server.

    The ``SEND_MODEL_UPDATE`` command is sent first, followed by the pickled
    state. A success dialog is shown once both sends complete; any exception
    is reported through an error dialog instead of being raised.
    """
    try:
        # Serialize before sending anything, so a pickling failure sends no command.
        payload = pickle.dumps(model.state_dict())
        for message in (b"SEND_MODEL_UPDATE", payload):
            client_socket.sendall(message)
        messagebox.showinfo("Success", "Model sent to server successfully.")
    except Exception as err:
        messagebox.showerror("Error", f"Failed to send model to server: {err}")
# Download global model, train locally, and send the update to the server
def download_train_send(client_socket, *, data_path="ecg_local.csv",
                        batch_size=32, epochs=1, lr=0.01):
    """Run one federated round: download, train locally, upload.

    Steps: build a fresh ``ECGNN``, load the global weights fetched with
    ``get_global_model``, train on the local CSV, then push the result with
    ``send_model_update``. Progress and failures are reported through
    message boxes; no exception escapes this function.

    Args:
        client_socket: Connected socket used to talk to the server.
        data_path: CSV file with the local data. Every column except the
            last is used as a feature and the last column as the label.
        batch_size: Mini-batch size for local training.
        epochs: Number of local training epochs.
        lr: Learning rate passed to ``train_local``.
    """
    try:
        # Initialize the client's model
        client_model = ECGNN()

        # Get the global model from the server
        global_model_state = get_global_model(client_socket)
        if global_model_state is None:
            # get_global_model has already shown the error dialog.
            return
        client_model.load_state_dict(global_model_state)
        messagebox.showinfo("Success", "Global model downloaded successfully.")

        # Load local data and train locally
        df = pd.read_csv(data_path)
        x = torch.tensor(df.iloc[:, :-1].values, dtype=torch.float32)   # ECG data points (features)
        # unsqueeze(1) gives labels shape (N, 1) to match the model output.
        y = torch.tensor(df.iloc[:, -1].values, dtype=torch.float32).unsqueeze(1)  # Labels

        dataset = TensorDataset(x, y)
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        train_local(client_model, data_loader, epochs=epochs, lr=lr)
        messagebox.showinfo("Success", "Model trained locally.")

        # Send the local model update to the server
        send_model_update(client_socket, client_model)

    except Exception as e:
        messagebox.showerror("Error", f"Failed to complete the operation: {e}")

# GUI for triggering the download, training, and sending process
def client_gui(client_socket):
    """Show the client window and block until it is closed.

    The window holds a title label and a "Start Training" button that runs
    ``download_train_send`` with ``client_socket``. The button callback runs
    inside the Tk event loop, so the window does not process events while a
    round is in progress.
    """
    window = Tk()
    window.title("Federated Learning Client")

    def on_start():
        # Button handler: run one download/train/send round.
        download_train_send(client_socket)

    heading = Label(window, text="Federated Learning Client", font=("Arial", 16))
    heading.pack(pady=20)

    trigger = Button(window, text="Start Training", command=on_start)
    trigger.pack(pady=20)

    window.mainloop()

# Client connection setup
def client_connect(host='LAPTOP-4757NS4C', port=5002):
    """Connect to the federated-learning server and run the client GUI.

    The socket is managed by a ``with`` block, so it is closed even if the
    connection attempt or the GUI raises.

    Args:
        host: Server host name or address. Replace the default with your
            server's address.
        port: Server TCP port.

    Raises:
        OSError: If the connection cannot be established.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client_socket:
        client_socket.connect((host, port))
        # Blocks until the GUI window is closed.
        client_gui(client_socket)

# Script entry point: connect to the server and open the GUI only when this
# file is executed directly, not when it is imported.
if __name__ == '__main__':
    client_connect()
