In [None]:
!pip install torch torchvision
!pip install medmnist



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import medmnist
from medmnist import INFO
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global torch RNG so the client split, model initialisation and
# DataLoader shuffling are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Choose the BreastMNIST dataset
data_flag = 'breastmnist'
info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = 2  # Binary classification for BreastMNIST

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.5], std=[0.5])  # (x - 0.5) / 0.5
])

# Load the dataset with transforms
DataClass = getattr(medmnist, info['python_class'])
train_data = DataClass(split='train', download=True, transform=transform)
test_data = DataClass(split='test', download=True, transform=transform)

# Number of clients the training set is partitioned across.
NUM_CLIENTS = 6

# Shuffle the training indices once (seeded above) ...
indices = torch.randperm(len(train_data)).tolist()

# ... then deal them out round-robin. Every sample lands in exactly one shard
# and shard sizes differ by at most one; the previous `len // 6` slicing
# silently dropped the remainder whenever the length was not divisible.
client_datasets = [Subset(train_data, indices[i::NUM_CLIENTS]) for i in range(NUM_CLIENTS)]

# Create data loaders for each client
batch_size = 32  # You can adjust the batch size
client_loaders = [DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in client_datasets]
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Define Client Model
class ClientModel(nn.Module):
    """Small CNN feature extractor run by each client.

    Two (conv 3x3 -> ReLU -> 2x2 max-pool) stages, then a flatten and a
    linear projection to a ``feature_dim``-wide embedding per image.

    Args:
        in_channels: Number of input image channels. When ``None`` it falls
            back to the notebook-level ``n_channels``, as before.
        image_size: Height/width of the (square) input images. Each
            ``MaxPool2d(2)`` halves the spatial size (rounding down), so the
            flattened conv output has ``32 * (image_size // 4) ** 2`` features.
            The default 28 reproduces the original hard-coded ``7 * 7 * 32``.
        feature_dim: Width of the output embedding (128 originally).
    """

    def __init__(self, in_channels=None, image_size=28, feature_dim=128):
        super().__init__()
        if in_channels is None:
            in_channels = n_channels  # notebook-level setting from the config cell
        # Padding=1 3x3 convs keep H/W; the two 2x2 pools floor-halve it twice.
        pooled = image_size // 4
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(pooled * pooled * 32, feature_dim)
        )

    def forward(self, x):
        """Map images ``(N, in_channels, image_size, image_size)`` to ``(N, feature_dim)``."""
        return self.model(x)

# Define Server Model
class ServerModel(nn.Module):
    """Classification head applied to the concatenated client embeddings.

    Args:
        in_features: Width of the concatenated client outputs. The default
            768 corresponds to 6 clients x 128 features each; pass
            ``num_clients * feature_dim`` if either of those changes.
        hidden_features: Width of the intermediate layer (128 originally).
        num_classes: Number of output classes. When ``None`` it falls back to
            the notebook-level ``n_classes``, as before.
    """

    def __init__(self, in_features=768, hidden_features=128, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = n_classes  # notebook-level setting from the config cell
        self.fc1 = nn.Linear(in_features, hidden_features)  # Intermediate layer
        self.fc2 = nn.Linear(hidden_features, num_classes)  # Final layer for classification

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(N, num_classes)``.

        Log-softmax (rather than raw logits) is emitted because the training
        and evaluation code scores it with ``nll_loss``.
        """
        x = torch.relu(self.fc1(x))
        return torch.log_softmax(self.fc2(x), dim=1)


# Initialize client and server models: one feature extractor per data shard.
# The count is derived from client_loaders so models and shards cannot drift
# apart (previously a separate hard-coded 6).
clients = [ClientModel().to(device) for _ in range(len(client_loaders))]
server = ServerModel().to(device)

# The server consumes the concatenation of every client's output, so its
# input width must equal the sum of the client output widths.
assert server.fc1.in_features == sum(c.model[-1].out_features for c in clients), \
    "ServerModel input width does not match the concatenated client outputs"

# Define optimizers: plain SGD, one per client plus one for the server head.
learning_rate = 0.01
client_optimizers = [optim.SGD(client.parameters(), lr=learning_rate) for client in clients]
server_optimizer = optim.SGD(server.parameters(), lr=learning_rate)

# Training function
def train(epoch):
    """Run one training epoch over every client's data shard.

    Clients take turns: for each (client optimizer, loader) pair, every batch
    from that loader is passed through *all* client models. Their outputs are
    concatenated and classified by the server head, and the NLL loss is
    back-propagated end to end. The server and the currently active client
    are then updated.

    Args:
        epoch: Epoch number, used only in the progress message.

    Returns:
        Mean of the per-batch losses over the epoch, or ``float('nan')`` if
        no batches were seen.
    """
    server.train()
    for client in clients:
        client.train()

    total_loss = 0.0
    num_batches = 0
    for active_optimizer, loader in zip(client_optimizers, client_loaders):
        for data, target in loader:
            data = data.to(device)
            # Labels arrive as (batch, 1); reshape(-1) flattens to (batch,)
            # and, unlike squeeze(), keeps a size-1 batch one-dimensional.
            target = target.to(device).reshape(-1).long()

            # Every client participates in the forward pass, so clear all
            # gradients (server included) to avoid stale accumulation.
            server_optimizer.zero_grad()
            for client_optimizer in client_optimizers:
                client_optimizer.zero_grad()

            client_outputs = [client(data) for client in clients]
            aggregated_output = torch.cat(client_outputs, dim=1)
            server_output = server(aggregated_output)
            loss = nn.functional.nll_loss(server_output, target)
            loss.backward()

            # Previously only the client optimizer stepped, so the server
            # head was never trained.
            active_optimizer.step()
            server_optimizer.step()

            total_loss += loss.item()
            num_batches += 1

    mean_loss = total_loss / num_batches if num_batches else float('nan')
    print(f"Epoch {epoch}: Loss: {mean_loss}")
    return mean_loss

# Testing function
def test():
    server.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            # Ensure target is 1D: target might be [batch_size, 1], so we squeeze it
            target = target.squeeze()

            # Concatenate client outputs
            client_outputs = [client(data) for client in clients]
            aggregated_output = torch.cat(client_outputs, dim=1)
            server_output = server(aggregated_output)

            # Calculate loss
            test_loss += nn.functional.nll_loss(server_output, target, reduction='sum').item()

            # Calculate accuracy
            pred = server_output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')


# Run training, then evaluate once on the test split.
# range(1, 50) previously ran only 49 epochs; epochs are now numbered
# 1..num_epochs inclusive.
num_epochs = 50  # Adjust number of epochs as needed
for epoch in range(1, num_epochs + 1):
    train(epoch)
test()


Using downloaded and verified file: /root/.medmnist/breastmnist.npz
Using downloaded and verified file: /root/.medmnist/breastmnist.npz
Epoch 1: Loss: 0.65912926197052
Epoch 2: Loss: 0.687837541103363
Epoch 3: Loss: 0.6464515924453735
Epoch 4: Loss: 0.652210533618927
Epoch 5: Loss: 0.661811113357544
Epoch 6: Loss: 0.6509944796562195
Epoch 7: Loss: 0.6878275871276855
Epoch 8: Loss: 0.6609188318252563
Epoch 9: Loss: 0.6327899694442749
Epoch 10: Loss: 0.6478314399719238
Epoch 11: Loss: 0.6672062873840332
Epoch 12: Loss: 0.6783787608146667
Epoch 13: Loss: 0.629980742931366
Epoch 14: Loss: 0.6407749056816101
Epoch 15: Loss: 0.7184495329856873
Epoch 16: Loss: 0.6506174802780151
Epoch 17: Loss: 0.6904046535491943
Epoch 18: Loss: 0.6477341055870056
Epoch 19: Loss: 0.6294332146644592
Epoch 20: Loss: 0.6760017275810242
Epoch 21: Loss: 0.6135280132293701
Epoch 22: Loss: 0.6153355836868286
Epoch 23: Loss: 0.6814242005348206
Epoch 24: Loss: 0.6244473457336426
Epoch 25: Loss: 0.611919105052948
Epoch