<a href="https://colab.research.google.com/github/arikalamonisha/miniproj/blob/main/Untitled6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
# Report the installed PyTorch build and whether a CUDA device is usable,
# one value per line (same output as two separate print calls).
print(torch.__version__, torch.cuda.is_available(), sep="\n")


2.6.0+cu124
False


In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import copy
import random

# --- Multi-Modal Model ---
class MultiModalModel(nn.Module):
    """Two-branch classifier fusing an image with a vector of tabular (EHR) features.

    The image branch is two ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)``
    stages followed by a flatten; the tabular branch is a two-layer MLP. The two
    feature vectors are concatenated and passed to an MLP head producing
    ``num_classes`` logits.

    Args:
        image_size: Spatial size of the input images, either an ``int`` (square
            images) or a ``(height, width)`` pair. The default of 32 reproduces
            the original hard-coded ``32 * 8 * 8`` flattened image feature size.
        in_channels: Number of image channels (default 1).
        num_tabular_features: Length of the tabular feature vector (default 10).
        num_classes: Number of output logits (default 2: COVID vs. non-COVID).

    Raises:
        ValueError: if the image is too small to survive the two 2x2 poolings.
    """

    def __init__(self, image_size=32, in_channels=1, num_tabular_features=10, num_classes=2):
        super().__init__()
        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size
        # 3x3 convs with padding=1 keep the spatial size; each MaxPool2d(2)
        # floors it by half, so after two stages the map is (H // 4) x (W // 4).
        pooled_h, pooled_w = height // 4, width // 4
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"image_size must be at least 4x4, got {height}x{width}"
            )
        image_features = 32 * pooled_h * pooled_w
        tabular_features = 16

        # Layers are created in the same order and (by default) with the same
        # shapes as before, so state-dict keys and seeded init are unchanged.
        self.image_branch = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(),
        )
        self.tabular_branch = nn.Sequential(
            nn.Linear(num_tabular_features, 32), nn.ReLU(),
            nn.Linear(32, tabular_features),
        )
        self.classifier = nn.Sequential(
            nn.Linear(image_features + tabular_features, 64), nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, image, tabular):
        """Return class logits for a batch of (image, tabular) pairs.

        Args:
            image: Tensor of shape ``(batch, in_channels, H, W)`` matching the
                ``image_size`` given at construction.
            tabular: Tensor of shape ``(batch, num_tabular_features)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        image_feats = self.image_branch(image)
        tabular_feats = self.tabular_branch(tabular)
        fused = torch.cat((image_feats, tabular_feats), dim=1)
        return self.classifier(fused)

# --- Simulate Client Training ---
def train_client(model, data_loader, epochs=1, lr=0.001):
    """Train a private copy of ``model`` on one client's data and return its weights.

    The ``model`` passed in is never modified: it is deep-copied first, so every
    simulated client starts from the same global weights.

    Args:
        model: Module called as ``model(image, tabular)`` that returns logits.
        data_loader: Iterable of ``(image, tabular, label)`` batches; ``label``
            is the target passed to ``nn.CrossEntropyLoss``.
        epochs: Number of passes over ``data_loader``.
        lr: Adam learning rate (default 0.001, the previously hard-coded value).

    Returns:
        The ``state_dict()`` of the locally trained copy.
    """
    local_model = copy.deepcopy(model)
    local_model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(local_model.parameters(), lr=lr)

    for _ in range(epochs):
        for image, tabular, label in data_loader:
            optimizer.zero_grad()
            loss = criterion(local_model(image, tabular), label)
            loss.backward()
            optimizer.step()
    return local_model.state_dict()

# --- Federated Averaging ---
def federated_avg(state_dicts, weights=None):
    """Average model state dicts entry-by-entry (Federated Averaging).

    Args:
        state_dicts: Non-empty sequence of state dicts sharing the same keys
            and tensor shapes (e.g. the outputs of ``train_client``).
        weights: Optional per-client non-negative weights, typically each
            client's number of training samples. ``None`` (default) gives the
            plain unweighted mean, matching the previous behaviour.

    Returns:
        A new state dict of the same container type as ``state_dicts[0]``.
        The input dicts and their tensors are not modified. Every entry keeps
        the dtype of the corresponding entry in ``state_dicts[0]``; integer or
        bool entries are averaged in floating point and then cast back (the
        fractional part is discarded).

    Raises:
        ValueError: if ``state_dicts`` is empty, ``weights`` has the wrong
            length, or the weights do not sum to a positive value.
    """
    if not state_dicts:
        raise ValueError("federated_avg() requires at least one state dict")
    if weights is None:
        weights = [1.0] * len(state_dicts)
    elif len(weights) != len(state_dicts):
        raise ValueError(
            f"got {len(weights)} weights for {len(state_dicts)} state dicts"
        )
    total = float(sum(weights))
    if total <= 0:
        raise ValueError("weights must sum to a positive value")

    # Shallow copy keeps the first dict's container type; every value is
    # rebound below, so no input tensor is ever mutated in place.
    avg_dict = copy.copy(state_dicts[0])
    for key, reference in state_dicts[0].items():
        weighted_sum = sum(w * sd[key] for w, sd in zip(weights, state_dicts))
        # Cast back so integer buffers (e.g. counters) keep their dtype instead
        # of silently becoming float as plain true division would make them.
        avg_dict[key] = (weighted_sum / total).to(reference.dtype)
    return avg_dict

# --- Simulation Setup ---
def simulate_federated_learning(global_model, client_loaders, rounds=5):
    """Run ``rounds`` rounds of simulated federated training.

    In each round every client trains a private copy of the current global
    model on its own loader (``train_client``). The resulting state dicts are
    combined with ``federated_avg`` and loaded back into ``global_model``,
    which is updated in place and also returned.

    Args:
        global_model: Model whose weights are broadcast to and aggregated from
            the clients.
        client_loaders: Sequence of per-client data loaders.
        rounds: Number of communication rounds.

    Returns:
        ``global_model`` after the final aggregation.
    """
    for round_no in range(1, rounds + 1):
        print(f"\n--- Round {round_no} ---")
        client_states = []
        for idx, client_loader in enumerate(client_loaders):
            print(f"Training on Client {idx}")
            client_states.append(train_client(global_model, client_loader))
        global_model.load_state_dict(federated_avg(client_states))
    return global_model
