<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Using_PyTorch_to_simulate_Federated_Learning_Model_in_Healthcare.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple neural network
class SimpleNN(nn.Module):
    """Two-layer MLP that maps each sample to sigmoid probabilities.

    Architecture: Linear(in_features, hidden_features) -> ReLU ->
    Linear(hidden_features, out_features) -> Sigmoid. The defaults
    (10 -> 50 -> 1) match the 10-feature simulated data and the single
    binary target used in this notebook.

    Args:
        in_features: Size of the last dimension of each input.
        hidden_features: Width of the hidden layer.
        out_features: Number of sigmoid outputs per sample.
    """

    def __init__(self, in_features=10, hidden_features=50, out_features=1):
        super().__init__()
        # Attribute names fc1/fc2 are kept so state_dict keys stay compatible.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Map ``x`` of shape (..., in_features) to values in [0, 1] of
        shape (..., out_features)."""
        x = torch.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(x))

# Function to fetch simulated hospital data
def get_encrypted_hospital_data(n_samples=100, n_features=10, generator=None):
    """Generate a random binary-classification dataset standing in for hospital data.

    NOTE: despite the name (kept for backward compatibility), nothing here is
    encrypted or fetched. Features are drawn from a standard normal
    distribution and labels uniformly from {0, 1}.

    Args:
        n_samples: Number of samples (rows) to generate.
        n_features: Number of features per sample.
        generator: Optional ``torch.Generator`` for reproducible draws;
            ``None`` uses the global RNG, exactly as before.

    Returns:
        ``(data, targets)``:
        - ``data``: a floating-point tensor (default dtype) of shape
          (n_samples, n_features).
        - ``targets``: an int64 tensor of shape (n_samples,) with values
          in {0, 1}.
    """
    data = torch.randn(n_samples, n_features, generator=generator)
    targets = torch.randint(0, 2, (n_samples,), generator=generator)
    return data, targets

# Function to split data among workers
def split_data(data, targets, n_workers):
    """Partition a dataset into per-worker ``(data, targets)`` shards.

    Samples are split along dimension 0 into contiguous shards whose sizes
    differ by at most one. Each simulated worker therefore gets a comparable
    amount of data.

    Args:
        data: Tensor of samples; dimension 0 indexes samples.
        targets: Tensor of labels with the same length along dimension 0.
        n_workers: Number of shards to produce; must be >= 1.

    Returns:
        A list of ``(data_shard, targets_shard)`` tuples.
        - With at least ``n_workers`` samples there are exactly ``n_workers``
          entries.
        - With fewer samples, each sample becomes its own shard. Empty shards
          are dropped so training never sees an empty batch.

    Raises:
        ValueError: If ``n_workers < 1`` or ``data`` and ``targets`` differ
            in length.
    """
    if n_workers < 1:
        raise ValueError(f"n_workers must be >= 1, got {n_workers}")
    if data.shape[0] != targets.shape[0]:
        raise ValueError(
            "data and targets differ in length: "
            f"{data.shape[0]} vs {targets.shape[0]}"
        )
    # torch.chunk uses ceil(n / chunks) samples per piece. It can therefore
    # return fewer than n_workers shards (6 samples / 4 chunks -> 3 shards)
    # and split unevenly. tensor_split always yields n_workers near-equal pieces.
    data_split = torch.tensor_split(data, n_workers)
    targets_split = torch.tensor_split(targets, n_workers)
    return [
        (x, y) for x, y in zip(data_split, targets_split) if y.shape[0] > 0
    ]

# Federated learning function
def train_federated_model(data_splits, model, epochs=10, lr=0.01):
    """Train ``model`` by visiting each worker's shard in turn every epoch.

    NOTE: this is a sequential, round-robin simulation, not Federated
    Averaging. One shared model and one SGD optimizer take a full-batch
    gradient step on each shard in order. No per-worker model copies are
    trained or averaged.

    Args:
        data_splits: Iterable of ``(data, targets)`` tensor pairs, e.g. the
            list returned by ``split_data``. ``targets`` hold 0/1 labels.
        model: Module whose flattened output has one value in [0, 1] per
            sample, as required by ``nn.BCELoss``.
        epochs: Number of passes over all shards (default 10).
        lr: SGD learning rate (default 0.01).

    Returns:
        The same ``model`` object, trained in place.

    Raises:
        ValueError: If ``data_splits`` contains no non-empty shard.
    """
    # Materialize so any iterable (e.g. a zip object) can be reused across
    # epochs. Drop empty shards, whose mean BCE loss would be NaN.
    shards = [(x, y) for x, y in data_splits if y.numel() > 0]
    if not shards:
        raise ValueError("data_splits must contain at least one non-empty shard")

    criterion = nn.BCELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for data, targets in shards:
            optimizer.zero_grad()
            # reshape (unlike view) also tolerates non-contiguous outputs.
            outputs = model(data).reshape(-1)
            loss = criterion(outputs, targets.float())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Unweighted mean of this epoch's per-shard losses.
        print(f"Epoch {epoch + 1}, Loss: {total_loss / len(shards)}")

    return model

# Simulated hospital data
# Guarded so importing this file does not train a model as a side effect.
# Jupyter/Colab run cells with __name__ == "__main__", so the demo still
# executes in the notebook.
if __name__ == "__main__":
    # Seed the global RNG so data generation and weight initialisation are
    # reproducible between runs.
    torch.manual_seed(0)

    data, targets = get_encrypted_hospital_data()
    data_splits = split_data(data, targets, n_workers=2)  # Splitting data among 2 workers

    # Initialize the model
    model = SimpleNN()

    # Train the model using federated learning
    trained_model = train_federated_model(data_splits, model)
    print("Federated Model Trained for Healthcare Insights")