In [None]:
!pip install torch torchvision
!pip install medmnist


Collecting medmnist
  Downloading medmnist-3.0.1-py3-none-any.whl (25 kB)
Collecting fire (from medmnist)
  Downloading fire-0.5.0.tar.gz (88 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.3/88.3 kB[0m [31m564.4 kB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.5.0-py2.py3-none-any.whl size=116934 sha256=d06af29a5b84847a6abc4e55114b57261ecc78284d0a94b9621b1f0020f55dde
  Stored in directory: /root/.cache/pip/wheels/90/d4/f7/9404e5db0116bd4d43e5666eaa3e70ab53723e1e3ea40c9a95
Successfully built fire
Installing collected packages: fire, medmnist
Successfully installed fire-0.5.0 medmnist-3.0.1


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import medmnist
from medmnist import INFO
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Choose the BreastMNIST dataset
data_flag = 'breastmnist'
info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = 2  # Binary classification for BreastMNIST

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load the dataset with transforms
DataClass = getattr(medmnist, info['python_class'])
train_data = DataClass(split='train', download=True, transform=transform)
test_data = DataClass(split='test', download=True, transform=transform)

# Split data into 2 parts for clients
split_size = len(train_data) // 2
client_datasets = [Subset(train_data, range(i * split_size, (i + 1) * split_size)) for i in range(2)]

# Create data loaders for each client
batch_size = 32
client_loaders = [DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in client_datasets]
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Client Model
class ClientModel(nn.Module):
    def __init__(self):
        super(ClientModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc = nn.Linear(7 * 7 * 32, 1568)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Hypernetwork
class Hypernetwork(nn.Module):
    def __init__(self):
        super(Hypernetwork, self).__init__()
        self.fc = nn.Linear(2 * 1568, 3136)

    def forward(self, x):
        return self.fc(x)

# Server Model
class ServerModel(nn.Module):
    def __init__(self):
        super(ServerModel, self).__init__()
        self.fc1 = nn.Linear(3136, 128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

# Initialize models and move to device
clients = [ClientModel().to(device) for _ in range(2)]
hypernetwork = Hypernetwork().to(device)
server = ServerModel().to(device)

# Define optimizers
client_optimizers = [optim.SGD(client.parameters(), lr=0.01) for client in clients]
hypernetwork_optimizer = optim.SGD(hypernetwork.parameters(), lr=0.01)
server_optimizer = optim.SGD(server.parameters(), lr=0.01)

# Training function
def train(epoch):
    server.train()
    hypernetwork.train()
    for client, optimizer in zip(clients, client_optimizers):
        client.train()
        for data, target in client_loaders[clients.index(client)]:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            optimizer.zero_grad()
            client_output = client(data)

            # Concatenate client outputs for the hypernetwork
            aggregated_output = torch.cat([client_output for client in clients], dim=1)

            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            loss = F.nll_loss(server_output, target)
            loss.backward()
            optimizer.step()

    print(f"Epoch {epoch}: Loss: {loss.item()}")

# def train(epoch):
#     server.train()
#     hypernetwork.train()
#     for client, optimizer in zip(clients, client_optimizers):
#         client.train()
#         for data, target in client_loaders[clients.index(client)]:
#             data, target = data.to(device), target.to(device)
#             target = target.squeeze()

#             optimizer.zero_grad()
#             client_output = client(data)
#             hypernetwork_output = hypernetwork(client_output)
#             server_output = server(hypernetwork_output)
#             loss = F.nll_loss(server_output, target)
#             loss.backward()
#             optimizer.step()

#     print(f"Epoch {epoch}: Loss: {loss.item()}")

# Testing function
def test():
    server.eval()
    hypernetwork.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            client_outputs = [client(data) for client in clients]
            aggregated_output = torch.cat(client_outputs, dim=1)
            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            test_loss += F.nll_loss(server_output, target, reduction='sum').item()
            pred = server_output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Run training and testing
for epoch in range(1, 50):
    train(epoch)
test()


Downloading https://zenodo.org/records/10519652/files/breastmnist.npz?download=1 to /root/.medmnist/breastmnist.npz


100%|██████████| 559580/559580 [00:00<00:00, 666185.45it/s]


Using downloaded and verified file: /root/.medmnist/breastmnist.npz
Epoch 1: Loss: 0.6883583664894104
Epoch 2: Loss: 0.671349287033081
Epoch 3: Loss: 0.6668608784675598
Epoch 4: Loss: 0.6677567958831787
Epoch 5: Loss: 0.6654928922653198
Epoch 6: Loss: 0.6452265381813049
Epoch 7: Loss: 0.6534948945045471
Epoch 8: Loss: 0.6489424705505371
Epoch 9: Loss: 0.6673902869224548
Epoch 10: Loss: 0.7092739939689636
Epoch 11: Loss: 0.7181890606880188
Epoch 12: Loss: 0.63253253698349
Epoch 13: Loss: 0.5538967251777649
Epoch 14: Loss: 0.7746447324752808
Epoch 15: Loss: 0.5327228903770447
Epoch 16: Loss: 0.6791759133338928
Epoch 17: Loss: 0.5765900611877441
Epoch 18: Loss: 0.5467226505279541
Epoch 19: Loss: 0.5927532911300659
Epoch 20: Loss: 0.496958464384079
Epoch 21: Loss: 0.7370741963386536
Epoch 22: Loss: 0.457574725151062
Epoch 23: Loss: 0.5658912062644958
Epoch 24: Loss: 0.6567019820213318
Epoch 25: Loss: 0.5558581352233887
Epoch 26: Loss: 0.6098633408546448
Epoch 27: Loss: 0.4762454926967621
E

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import medmnist
from medmnist import INFO
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Choose the BreastMNIST dataset
data_flag = 'breastmnist'
info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = 2  # Binary classification for BreastMNIST

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load the dataset with transforms
DataClass = getattr(medmnist, info['python_class'])
train_data = DataClass(split='train', download=True, transform=transform)
test_data = DataClass(split='test', download=True, transform=transform)

# Split data into 2 parts for clients
split_size = len(train_data) // 2
client_datasets = [Subset(train_data, range(i * split_size, (i + 1) * split_size)) for i in range(2)]

# Create data loaders for each client
batch_size = 32
client_loaders = [DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in client_datasets]
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Client Model
class ClientModel(nn.Module):
    def __init__(self):
        super(ClientModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc = nn.Linear(7 * 7 * 32, 1568)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Hypernetwork
class Hypernetwork(nn.Module):
    def __init__(self):
        super(Hypernetwork, self).__init__()
        self.fc = nn.Linear(2 * 1568, 6272)

    def forward(self, x):
        return self.fc(x)

# Server Model
class ServerModel(nn.Module):
    def __init__(self):
        super(ServerModel, self).__init__()
        self.fc1 = nn.Linear(6272, 128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

# Initialize models and move to device
clients = [ClientModel().to(device) for _ in range(2)]
hypernetwork = Hypernetwork().to(device)
server = ServerModel().to(device)

# Define optimizers
client_optimizers = [optim.SGD(client.parameters(), lr=0.01) for client in clients]
hypernetwork_optimizer = optim.SGD(hypernetwork.parameters(), lr=0.01)
server_optimizer = optim.SGD(server.parameters(), lr=0.01)

# Training function
def train(epoch):
    server.train()
    hypernetwork.train()
    for client, optimizer in zip(clients, client_optimizers):
        client.train()
        for data, target in client_loaders[clients.index(client)]:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            optimizer.zero_grad()
            client_output = client(data)

            # Concatenate client outputs for the hypernetwork
            aggregated_output = torch.cat([client_output for client in clients], dim=1)

            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            loss = F.nll_loss(server_output, target)
            loss.backward()
            optimizer.step()

    print(f"Epoch {epoch}: Loss: {loss.item()}")

# def train(epoch):
#     server.train()
#     hypernetwork.train()
#     for client, optimizer in zip(clients, client_optimizers):
#         client.train()
#         for data, target in client_loaders[clients.index(client)]:
#             data, target = data.to(device), target.to(device)
#             target = target.squeeze()

#             optimizer.zero_grad()
#             client_output = client(data)
#             hypernetwork_output = hypernetwork(client_output)
#             server_output = server(hypernetwork_output)
#             loss = F.nll_loss(server_output, target)
#             loss.backward()
#             optimizer.step()

#     print(f"Epoch {epoch}: Loss: {loss.item()}")

# Testing function
def test():
    server.eval()
    hypernetwork.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            client_outputs = [client(data) for client in clients]
            aggregated_output = torch.cat(client_outputs, dim=1)
            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            test_loss += F.nll_loss(server_output, target, reduction='sum').item()
            pred = server_output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Run training and testing
for epoch in range(1, 100):
    train(epoch)
test()


Using downloaded and verified file: /root/.medmnist/breastmnist.npz
Using downloaded and verified file: /root/.medmnist/breastmnist.npz
Epoch 1: Loss: 0.6768783926963806
Epoch 2: Loss: 0.6754679679870605
Epoch 3: Loss: 0.6742483377456665
Epoch 4: Loss: 0.6575452089309692
Epoch 5: Loss: 0.6548781394958496
Epoch 6: Loss: 0.66746985912323
Epoch 7: Loss: 0.6751784682273865
Epoch 8: Loss: 0.6828363537788391
Epoch 9: Loss: 0.6424790024757385
Epoch 10: Loss: 0.6728195548057556
Epoch 11: Loss: 0.6956724524497986
Epoch 12: Loss: 0.6534914970397949
Epoch 13: Loss: 0.6510999202728271
Epoch 14: Loss: 0.6396979093551636
Epoch 15: Loss: 0.679373562335968
Epoch 16: Loss: 0.6588160991668701
Epoch 17: Loss: 0.6643362641334534
Epoch 18: Loss: 0.6189576387405396
Epoch 19: Loss: 0.6766189336776733
Epoch 20: Loss: 0.63462233543396
Epoch 21: Loss: 0.65126633644104
Epoch 22: Loss: 0.6316800117492676
Epoch 23: Loss: 0.6482616066932678
Epoch 24: Loss: 0.6430801153182983
Epoch 25: Loss: 0.6559993028640747
Epoch

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import medmnist
from medmnist import INFO
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Choose the BreastMNIST dataset
data_flag = 'breastmnist'
info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = 2  # Binary classification for BreastMNIST

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load the dataset with transforms
DataClass = getattr(medmnist, info['python_class'])
train_data = DataClass(split='train', download=True, transform=transform)
test_data = DataClass(split='test', download=True, transform=transform)

# Split data into 2 parts for clients
split_size = len(train_data) // 2
client_datasets = [Subset(train_data, range(i * split_size, (i + 1) * split_size)) for i in range(2)]

# Create data loaders for each client
batch_size = 32
client_loaders = [DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in client_datasets]
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Client Model
class ClientModel(nn.Module):
    def __init__(self):
        super(ClientModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc = nn.Linear(7 * 7 * 32, 1568)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Hypernetwork
class Hypernetwork(nn.Module):
    def __init__(self):
        super(Hypernetwork, self).__init__()
        self.fc = nn.Linear(2 * 1568, 9408)

    def forward(self, x):
        return self.fc(x)

# Server Model
class ServerModel(nn.Module):
    def __init__(self):
        super(ServerModel, self).__init__()
        self.fc1 = nn.Linear(9408,128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

# Initialize models and move to device
clients = [ClientModel().to(device) for _ in range(2)]
hypernetwork = Hypernetwork().to(device)
server = ServerModel().to(device)

# Define optimizers
client_optimizers = [optim.SGD(client.parameters(), lr=0.01) for client in clients]
hypernetwork_optimizer = optim.SGD(hypernetwork.parameters(), lr=0.01)
server_optimizer = optim.SGD(server.parameters(), lr=0.01)

# Training function
def train(epoch):
    server.train()
    hypernetwork.train()
    for client, optimizer in zip(clients, client_optimizers):
        client.train()
        for data, target in client_loaders[clients.index(client)]:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            optimizer.zero_grad()
            client_output = client(data)

            # Concatenate client outputs for the hypernetwork
            aggregated_output = torch.cat([client_output for client in clients], dim=1)

            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            loss = F.nll_loss(server_output, target)
            loss.backward()
            optimizer.step()

    print(f"Epoch {epoch}: Loss: {loss.item()}")

# def train(epoch):
#     server.train()
#     hypernetwork.train()
#     for client, optimizer in zip(clients, client_optimizers):
#         client.train()
#         for data, target in client_loaders[clients.index(client)]:
#             data, target = data.to(device), target.to(device)
#             target = target.squeeze()

#             optimizer.zero_grad()
#             client_output = client(data)
#             hypernetwork_output = hypernetwork(client_output)
#             server_output = server(hypernetwork_output)
#             loss = F.nll_loss(server_output, target)
#             loss.backward()
#             optimizer.step()

#     print(f"Epoch {epoch}: Loss: {loss.item()}")

# Testing function
def test():
    server.eval()
    hypernetwork.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            client_outputs = [client(data) for client in clients]
            aggregated_output = torch.cat(client_outputs, dim=1)
            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            test_loss += F.nll_loss(server_output, target, reduction='sum').item()
            pred = server_output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Run training and testing
for epoch in range(1, 100):
    train(epoch)
test()


Using downloaded and verified file: /root/.medmnist/breastmnist.npz
Using downloaded and verified file: /root/.medmnist/breastmnist.npz
Epoch 1: Loss: 0.6930980086326599
Epoch 2: Loss: 0.6838165521621704
Epoch 3: Loss: 0.6938560605049133
Epoch 4: Loss: 0.672053873538971
Epoch 5: Loss: 0.6876316666603088
Epoch 6: Loss: 0.6692244410514832
Epoch 7: Loss: 0.6842911839485168
Epoch 8: Loss: 0.678032636642456
Epoch 9: Loss: 0.6776852607727051
Epoch 10: Loss: 0.642151415348053
Epoch 11: Loss: 0.6593918204307556
Epoch 12: Loss: 0.6857162117958069
Epoch 13: Loss: 0.6832902431488037
Epoch 14: Loss: 0.6568840742111206
Epoch 15: Loss: 0.6704747676849365
Epoch 16: Loss: 0.7113478779792786
Epoch 17: Loss: 0.6652343273162842
Epoch 18: Loss: 0.6235612034797668
Epoch 19: Loss: 0.6905626654624939
Epoch 20: Loss: 0.6620615124702454
Epoch 21: Loss: 0.6399678587913513
Epoch 22: Loss: 0.6233506202697754
Epoch 23: Loss: 0.6755709648132324
Epoch 24: Loss: 0.6924368739128113
Epoch 25: Loss: 0.5685511231422424
E

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import medmnist
from medmnist import INFO
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Choose the BreastMNIST dataset
data_flag = 'breastmnist'
info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = 2  # Binary classification for BreastMNIST

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load the dataset with transforms
DataClass = getattr(medmnist, info['python_class'])
train_data = DataClass(split='train', download=True, transform=transform)
test_data = DataClass(split='test', download=True, transform=transform)

# Split data into 2 parts for clients
split_size = len(train_data) // 2
client_datasets = [Subset(train_data, range(i * split_size, (i + 1) * split_size)) for i in range(2)]

# Create data loaders for each client
batch_size = 32
client_loaders = [DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in client_datasets]
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Client Model
class ClientModel(nn.Module):
    def __init__(self):
        super(ClientModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc = nn.Linear(7 * 7 * 32, 1568)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Hypernetwork
class Hypernetwork(nn.Module):
    def __init__(self):
        super(Hypernetwork, self).__init__()
        self.fc = nn.Linear(2 * 1568, 12544)

    def forward(self, x):
        return self.fc(x)

# Server Model
class ServerModel(nn.Module):
    def __init__(self):
        super(ServerModel, self).__init__()
        self.fc1 = nn.Linear(12544, 128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

# Initialize models and move to device
clients = [ClientModel().to(device) for _ in range(2)]
hypernetwork = Hypernetwork().to(device)
server = ServerModel().to(device)

# Define optimizers
client_optimizers = [optim.SGD(client.parameters(), lr=0.01) for client in clients]
hypernetwork_optimizer = optim.SGD(hypernetwork.parameters(), lr=0.01)
server_optimizer = optim.SGD(server.parameters(), lr=0.01)

# Training function
def train(epoch):
    server.train()
    hypernetwork.train()
    for client, optimizer in zip(clients, client_optimizers):
        client.train()
        for data, target in client_loaders[clients.index(client)]:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            optimizer.zero_grad()
            client_output = client(data)

            # Concatenate client outputs for the hypernetwork
            aggregated_output = torch.cat([client_output for client in clients], dim=1)

            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            loss = F.nll_loss(server_output, target)
            loss.backward()
            optimizer.step()

    print(f"Epoch {epoch}: Loss: {loss.item()}")


# def train(epoch):
#     server.train()
#     hypernetwork.train()
#     for client, optimizer in zip(clients, client_optimizers):
#         client.train()
#         for data, target in client_loaders[clients.index(client)]:
#             data, target = data.to(device), target.to(device)
#             target = target.squeeze()

#             optimizer.zero_grad()
#             client_output = client(data)
#             hypernetwork_output = hypernetwork(client_output)
#             server_output = server(hypernetwork_output)
#             loss = F.nll_loss(server_output, target)
#             loss.backward()
#             optimizer.step()

#     print(f"Epoch {epoch}: Loss: {loss.item()}")

# Testing function
def test():
    server.eval()
    hypernetwork.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            client_outputs = [client(data) for client in clients]
            aggregated_output = torch.cat(client_outputs, dim=1)
            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            test_loss += F.nll_loss(server_output, target, reduction='sum').item()
            pred = server_output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Run training and testing
for epoch in range(1, 100):
    train(epoch)
test()


Using downloaded and verified file: /root/.medmnist/breastmnist.npz
Using downloaded and verified file: /root/.medmnist/breastmnist.npz
Epoch 1: Loss: 0.6754650473594666
Epoch 2: Loss: 0.6561296582221985
Epoch 3: Loss: 0.6628713011741638
Epoch 4: Loss: 0.6412376165390015
Epoch 5: Loss: 0.6693846583366394
Epoch 6: Loss: 0.6676784157752991
Epoch 7: Loss: 0.6293926239013672
Epoch 8: Loss: 0.6508394479751587
Epoch 9: Loss: 0.6944422721862793
Epoch 10: Loss: 0.6154557466506958
Epoch 11: Loss: 0.6283389329910278
Epoch 12: Loss: 0.6770913004875183
Epoch 13: Loss: 0.6030018925666809
Epoch 14: Loss: 0.5969135165214539
Epoch 15: Loss: 0.6581016778945923
Epoch 16: Loss: 0.6512589454650879
Epoch 17: Loss: 0.609336256980896
Epoch 18: Loss: 0.5751552581787109
Epoch 19: Loss: 0.6743370294570923
Epoch 20: Loss: 0.5975362658500671
Epoch 21: Loss: 0.5665944218635559
Epoch 22: Loss: 0.7112487554550171
Epoch 23: Loss: 0.6511773467063904
Epoch 24: Loss: 0.6499750018119812
Epoch 25: Loss: 0.7133002281188965