In [None]:
!pip install torch torchvision
!pip install medmnist

Collecting medmnist
  Downloading medmnist-3.0.1-py3-none-any.whl (25 kB)
Collecting fire (from medmnist)
  Downloading fire-0.5.0.tar.gz (88 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.3/88.3 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.5.0-py2.py3-none-any.whl size=116934 sha256=2a1fbb49ccbfc96e09b4dc2e025426c46cb694b84eceafc04cc8907fc2a350eb
  Stored in directory: /root/.cache/pip/wheels/90/d4/f7/9404e5db0116bd4d43e5666eaa3e70ab53723e1e3ea40c9a95
Successfully built fire
Installing collected packages: fire, medmnist
Successfully installed fire-0.5.0 medmnist-3.0.1


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import medmnist
from medmnist import INFO
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Choose the PneumoniaMNIST dataset
data_flag = 'pneumoniamnist'
info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = 2  # Binary classification for PneumoniaMNIST

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load the dataset with transforms
DataClass = getattr(medmnist, info['python_class'])
train_data = DataClass(split='train', download=True, transform=transform)
test_data = DataClass(split='test', download=True, transform=transform)

# Split data into 2 parts for clients
split_size = len(train_data) // 2
client_datasets = [Subset(train_data, range(i * split_size, (i + 1) * split_size)) for i in range(2)]

# Create data loaders for each client
batch_size = 32
client_loaders = [DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in client_datasets]
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Client Model
class ClientModel(nn.Module):
    def __init__(self):
        super(ClientModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc = nn.Linear(7 * 7 * 32, 1568)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Hypernetwork
class Hypernetwork(nn.Module):
    def __init__(self):
        super(Hypernetwork, self).__init__()
        self.fc = nn.Linear(2 * 1568, 3136)

    def forward(self, x):
        return self.fc(x)

# Server Model
class ServerModel(nn.Module):
    def __init__(self):
        super(ServerModel, self).__init__()
        self.fc1 = nn.Linear(3136, 128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

# Initialize models and move to device
clients = [ClientModel().to(device) for _ in range(2)]
hypernetwork = Hypernetwork().to(device)
server = ServerModel().to(device)

# Define optimizers
client_optimizers = [optim.SGD(client.parameters(), lr=0.01) for client in clients]
hypernetwork_optimizer = optim.SGD(hypernetwork.parameters(), lr=0.01)
server_optimizer = optim.SGD(server.parameters(), lr=0.01)

# Training function
def train(epoch):
    server.train()
    hypernetwork.train()
    for client, optimizer in zip(clients, client_optimizers):
        client.train()
        for data, target in client_loaders[clients.index(client)]:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            optimizer.zero_grad()
            client_output = client(data)

            # Concatenate client outputs for the hypernetwork
            aggregated_output = torch.cat([client_output for client in clients], dim=1)

            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            loss = F.nll_loss(server_output, target)
            loss.backward()
            optimizer.step()

    print(f"Epoch {epoch}: Loss: {loss.item()}")

# def train(epoch):
#     server.train()
#     hypernetwork.train()
#     for client, optimizer in zip(clients, client_optimizers):
#         client.train()
#         for data, target in client_loaders[clients.index(client)]:
#             data, target = data.to(device), target.to(device)
#             target = target.squeeze()

#             optimizer.zero_grad()
#             client_output = client(data)
#             hypernetwork_output = hypernetwork(client_output)
#             server_output = server(hypernetwork_output)
#             loss = F.nll_loss(server_output, target)
#             loss.backward()
#             optimizer.step()

#     print(f"Epoch {epoch}: Loss: {loss.item()}")

# Testing function
def test():
    server.eval()
    hypernetwork.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            client_outputs = [client(data) for client in clients]
            aggregated_output = torch.cat(client_outputs, dim=1)
            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            test_loss += F.nll_loss(server_output, target, reduction='sum').item()
            pred = server_output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Run training and testing
for epoch in range(1, 50):
    train(epoch)
test()


Downloading https://zenodo.org/records/10519652/files/pneumoniamnist.npz?download=1 to /root/.medmnist/pneumoniamnist.npz


100%|██████████| 4170669/4170669 [00:01<00:00, 2956916.62it/s]


Using downloaded and verified file: /root/.medmnist/pneumoniamnist.npz
Epoch 1: Loss: 0.6733518838882446
Epoch 2: Loss: 0.6655266284942627
Epoch 3: Loss: 0.6521260142326355
Epoch 4: Loss: 0.42444413900375366
Epoch 5: Loss: 0.41548973321914673
Epoch 6: Loss: 0.5763504505157471
Epoch 7: Loss: 0.32338136434555054
Epoch 8: Loss: 0.4009271264076233
Epoch 9: Loss: 0.5862756371498108
Epoch 10: Loss: 0.39928188920021057
Epoch 11: Loss: 0.5817214250564575
Epoch 12: Loss: 0.4661250710487366
Epoch 13: Loss: 0.3591834306716919
Epoch 14: Loss: 0.33728569746017456
Epoch 15: Loss: 0.3599308133125305
Epoch 16: Loss: 0.5045976638793945
Epoch 17: Loss: 0.5756531953811646
Epoch 18: Loss: 0.33842259645462036
Epoch 19: Loss: 0.46728765964508057
Epoch 20: Loss: 0.5355247855186462
Epoch 21: Loss: 0.5731016993522644
Epoch 22: Loss: 0.3943222463130951
Epoch 23: Loss: 0.540066659450531
Epoch 24: Loss: 0.41308891773223877
Epoch 25: Loss: 0.3283945322036743
Epoch 26: Loss: 0.3636084794998169
Epoch 27: Loss: 0.345

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import medmnist
from medmnist import INFO
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Choose the pneumoniamnist dataset
data_flag = 'pneumoniamnist'
info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = 2  # Binary classification for pneumoniamnist

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load the dataset with transforms
DataClass = getattr(medmnist, info['python_class'])
train_data = DataClass(split='train', download=True, transform=transform)
test_data = DataClass(split='test', download=True, transform=transform)

# Split data into 2 parts for clients
split_size = len(train_data) // 2
client_datasets = [Subset(train_data, range(i * split_size, (i + 1) * split_size)) for i in range(2)]

# Create data loaders for each client
batch_size = 32
client_loaders = [DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in client_datasets]
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Client Model
class ClientModel(nn.Module):
    def __init__(self):
        super(ClientModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc = nn.Linear(7 * 7 * 32, 1568)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Hypernetwork
class Hypernetwork(nn.Module):
    def __init__(self):
        super(Hypernetwork, self).__init__()
        self.fc = nn.Linear(2 * 1568, 6272)

    def forward(self, x):
        return self.fc(x)

# Server Model
class ServerModel(nn.Module):
    def __init__(self):
        super(ServerModel, self).__init__()
        self.fc1 = nn.Linear(6272, 128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

# Initialize models and move to device
clients = [ClientModel().to(device) for _ in range(2)]
hypernetwork = Hypernetwork().to(device)
server = ServerModel().to(device)

# Define optimizers
client_optimizers = [optim.SGD(client.parameters(), lr=0.01) for client in clients]
hypernetwork_optimizer = optim.SGD(hypernetwork.parameters(), lr=0.01)
server_optimizer = optim.SGD(server.parameters(), lr=0.01)

# Training function
def train(epoch):
    server.train()
    hypernetwork.train()
    for client, optimizer in zip(clients, client_optimizers):
        client.train()
        for data, target in client_loaders[clients.index(client)]:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            optimizer.zero_grad()
            client_output = client(data)

            # Concatenate client outputs for the hypernetwork
            aggregated_output = torch.cat([client_output for client in clients], dim=1)

            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            loss = F.nll_loss(server_output, target)
            loss.backward()
            optimizer.step()

    print(f"Epoch {epoch}: Loss: {loss.item()}")

# def train(epoch):
#     server.train()
#     hypernetwork.train()
#     for client, optimizer in zip(clients, client_optimizers):
#         client.train()
#         for data, target in client_loaders[clients.index(client)]:
#             data, target = data.to(device), target.to(device)
#             target = target.squeeze()

#             optimizer.zero_grad()
#             client_output = client(data)
#             hypernetwork_output = hypernetwork(client_output)
#             server_output = server(hypernetwork_output)
#             loss = F.nll_loss(server_output, target)
#             loss.backward()
#             optimizer.step()

#     print(f"Epoch {epoch}: Loss: {loss.item()}")

# Testing function
def test():
    server.eval()
    hypernetwork.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            client_outputs = [client(data) for client in clients]
            aggregated_output = torch.cat(client_outputs, dim=1)
            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            test_loss += F.nll_loss(server_output, target, reduction='sum').item()
            pred = server_output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Run training and testing
for epoch in range(1, 100):
    train(epoch)
test()


Using downloaded and verified file: /root/.medmnist/pneumoniamnist.npz
Using downloaded and verified file: /root/.medmnist/pneumoniamnist.npz
Epoch 1: Loss: 0.6921656131744385
Epoch 2: Loss: 0.6226744055747986
Epoch 3: Loss: 0.49891021847724915
Epoch 4: Loss: 0.5279533863067627
Epoch 5: Loss: 0.41030046343803406
Epoch 6: Loss: 0.5874865055084229
Epoch 7: Loss: 0.540465235710144
Epoch 8: Loss: 0.6240466237068176
Epoch 9: Loss: 0.41853514313697815
Epoch 10: Loss: 0.43741363286972046
Epoch 11: Loss: 0.4619804620742798
Epoch 12: Loss: 0.524450957775116
Epoch 13: Loss: 0.5939149260520935
Epoch 14: Loss: 0.44500336050987244
Epoch 15: Loss: 0.38639166951179504
Epoch 16: Loss: 0.45757707953453064
Epoch 17: Loss: 0.5665512084960938
Epoch 18: Loss: 0.40811577439308167
Epoch 19: Loss: 0.42653530836105347
Epoch 20: Loss: 0.292697548866272
Epoch 21: Loss: 0.23548801243305206
Epoch 22: Loss: 0.2106475830078125
Epoch 23: Loss: 0.21919457614421844
Epoch 24: Loss: 0.3307584822177887
Epoch 25: Loss: 0.3

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import medmnist
from medmnist import INFO
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Choose the pneumoniamnist dataset
data_flag = 'pneumoniamnist'
info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = 2  # Binary classification for pneumoniamnist

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load the dataset with transforms
DataClass = getattr(medmnist, info['python_class'])
train_data = DataClass(split='train', download=True, transform=transform)
test_data = DataClass(split='test', download=True, transform=transform)

# Split data into 2 parts for clients
split_size = len(train_data) // 2
client_datasets = [Subset(train_data, range(i * split_size, (i + 1) * split_size)) for i in range(2)]

# Create data loaders for each client
batch_size = 32
client_loaders = [DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in client_datasets]
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Client Model
class ClientModel(nn.Module):
    def __init__(self):
        super(ClientModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc = nn.Linear(7 * 7 * 32, 1568)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Hypernetwork
class Hypernetwork(nn.Module):
    def __init__(self):
        super(Hypernetwork, self).__init__()
        self.fc = nn.Linear(2 * 1568, 9408)

    def forward(self, x):
        return self.fc(x)

# Server Model
class ServerModel(nn.Module):
    def __init__(self):
        super(ServerModel, self).__init__()
        self.fc1 = nn.Linear(9408,128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

# Initialize models and move to device
clients = [ClientModel().to(device) for _ in range(2)]
hypernetwork = Hypernetwork().to(device)
server = ServerModel().to(device)

# Define optimizers
client_optimizers = [optim.SGD(client.parameters(), lr=0.01) for client in clients]
hypernetwork_optimizer = optim.SGD(hypernetwork.parameters(), lr=0.01)
server_optimizer = optim.SGD(server.parameters(), lr=0.01)

# Training function
def train(epoch):
    server.train()
    hypernetwork.train()
    for client, optimizer in zip(clients, client_optimizers):
        client.train()
        for data, target in client_loaders[clients.index(client)]:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            optimizer.zero_grad()
            client_output = client(data)

            # Concatenate client outputs for the hypernetwork
            aggregated_output = torch.cat([client_output for client in clients], dim=1)

            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            loss = F.nll_loss(server_output, target)
            loss.backward()
            optimizer.step()

    print(f"Epoch {epoch}: Loss: {loss.item()}")

# def train(epoch):
#     server.train()
#     hypernetwork.train()
#     for client, optimizer in zip(clients, client_optimizers):
#         client.train()
#         for data, target in client_loaders[clients.index(client)]:
#             data, target = data.to(device), target.to(device)
#             target = target.squeeze()

#             optimizer.zero_grad()
#             client_output = client(data)
#             hypernetwork_output = hypernetwork(client_output)
#             server_output = server(hypernetwork_output)
#             loss = F.nll_loss(server_output, target)
#             loss.backward()
#             optimizer.step()

#     print(f"Epoch {epoch}: Loss: {loss.item()}")

# Testing function
def test():
    server.eval()
    hypernetwork.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            client_outputs = [client(data) for client in clients]
            aggregated_output = torch.cat(client_outputs, dim=1)
            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            test_loss += F.nll_loss(server_output, target, reduction='sum').item()
            pred = server_output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Run training and testing
for epoch in range(1, 100):
    train(epoch)
test()


Using downloaded and verified file: /root/.medmnist/pneumoniamnist.npz
Using downloaded and verified file: /root/.medmnist/pneumoniamnist.npz
Epoch 1: Loss: 0.680432140827179
Epoch 2: Loss: 0.6306365728378296
Epoch 3: Loss: 0.5416367053985596
Epoch 4: Loss: 0.7692310810089111
Epoch 5: Loss: 0.5918901562690735
Epoch 6: Loss: 0.6436994671821594
Epoch 7: Loss: 0.41212257742881775
Epoch 8: Loss: 0.47893026471138
Epoch 9: Loss: 0.4861050844192505
Epoch 10: Loss: 0.40638914704322815
Epoch 11: Loss: 0.7082064747810364
Epoch 12: Loss: 0.49489593505859375
Epoch 13: Loss: 0.5097054243087769
Epoch 14: Loss: 0.5312706828117371
Epoch 15: Loss: 0.5541714429855347
Epoch 16: Loss: 0.5407376289367676
Epoch 17: Loss: 0.6039213538169861
Epoch 18: Loss: 0.49199825525283813
Epoch 19: Loss: 0.43319809436798096
Epoch 20: Loss: 0.48157745599746704
Epoch 21: Loss: 0.5377501845359802
Epoch 22: Loss: 0.2899913787841797
Epoch 23: Loss: 0.34463822841644287
Epoch 24: Loss: 0.3387990891933441
Epoch 25: Loss: 0.29109

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import medmnist
from medmnist import INFO
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Choose the pneumoniamnist dataset
data_flag = 'pneumoniamnist'
info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = 2  # Binary classification for pneumoniamnist

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load the dataset with transforms
DataClass = getattr(medmnist, info['python_class'])
train_data = DataClass(split='train', download=True, transform=transform)
test_data = DataClass(split='test', download=True, transform=transform)

# Split data into 2 parts for clients
split_size = len(train_data) // 2
client_datasets = [Subset(train_data, range(i * split_size, (i + 1) * split_size)) for i in range(2)]

# Create data loaders for each client
batch_size = 32
client_loaders = [DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in client_datasets]
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Client Model
class ClientModel(nn.Module):
    def __init__(self):
        super(ClientModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc = nn.Linear(7 * 7 * 32, 1568)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Hypernetwork
class Hypernetwork(nn.Module):
    def __init__(self):
        super(Hypernetwork, self).__init__()
        self.fc = nn.Linear(2 * 1568, 12544)

    def forward(self, x):
        return self.fc(x)

# Server Model
class ServerModel(nn.Module):
    def __init__(self):
        super(ServerModel, self).__init__()
        self.fc1 = nn.Linear(12544, 128)
        self.fc2 = nn.Linear(128, n_classes)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)

# Initialize models and move to device
clients = [ClientModel().to(device) for _ in range(2)]
hypernetwork = Hypernetwork().to(device)
server = ServerModel().to(device)

# Define optimizers
client_optimizers = [optim.SGD(client.parameters(), lr=0.01) for client in clients]
hypernetwork_optimizer = optim.SGD(hypernetwork.parameters(), lr=0.01)
server_optimizer = optim.SGD(server.parameters(), lr=0.01)

# Training function
def train(epoch):
    server.train()
    hypernetwork.train()
    for client, optimizer in zip(clients, client_optimizers):
        client.train()
        for data, target in client_loaders[clients.index(client)]:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            optimizer.zero_grad()
            client_output = client(data)

            # Concatenate client outputs for the hypernetwork
            aggregated_output = torch.cat([client_output for client in clients], dim=1)

            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            loss = F.nll_loss(server_output, target)
            loss.backward()
            optimizer.step()

    print(f"Epoch {epoch}: Loss: {loss.item()}")


# def train(epoch):
#     server.train()
#     hypernetwork.train()
#     for client, optimizer in zip(clients, client_optimizers):
#         client.train()
#         for data, target in client_loaders[clients.index(client)]:
#             data, target = data.to(device), target.to(device)
#             target = target.squeeze()

#             optimizer.zero_grad()
#             client_output = client(data)
#             hypernetwork_output = hypernetwork(client_output)
#             server_output = server(hypernetwork_output)
#             loss = F.nll_loss(server_output, target)
#             loss.backward()
#             optimizer.step()

#     print(f"Epoch {epoch}: Loss: {loss.item()}")

# Testing function
def test():
    server.eval()
    hypernetwork.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            target = target.squeeze()

            client_outputs = [client(data) for client in clients]
            aggregated_output = torch.cat(client_outputs, dim=1)
            hypernetwork_output = hypernetwork(aggregated_output)
            server_output = server(hypernetwork_output)
            test_loss += F.nll_loss(server_output, target, reduction='sum').item()
            pred = server_output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Run training and testing
for epoch in range(1, 100):
    train(epoch)
test()


Using downloaded and verified file: /root/.medmnist/pneumoniamnist.npz
Using downloaded and verified file: /root/.medmnist/pneumoniamnist.npz
Epoch 1: Loss: 0.6120525002479553
Epoch 2: Loss: 0.6775522232055664
Epoch 3: Loss: 0.5832805633544922
Epoch 4: Loss: 0.37631213665008545
Epoch 5: Loss: 0.4235135316848755
Epoch 6: Loss: 0.5206459760665894
Epoch 7: Loss: 0.7213881015777588
Epoch 8: Loss: 0.4024496376514435
Epoch 9: Loss: 0.585258960723877
Epoch 10: Loss: 0.49598339200019836
Epoch 11: Loss: 0.5632988810539246
Epoch 12: Loss: 0.503531277179718
Epoch 13: Loss: 0.6039111018180847
Epoch 14: Loss: 0.5430577993392944
Epoch 15: Loss: 0.5195146799087524
Epoch 16: Loss: 0.4415673315525055
Epoch 17: Loss: 0.489604115486145
Epoch 18: Loss: 0.4360673427581787
Epoch 19: Loss: 0.5297596454620361
Epoch 20: Loss: 0.4747216999530792
Epoch 21: Loss: 0.4676307439804077
Epoch 22: Loss: 0.31709402799606323
Epoch 23: Loss: 0.27337950468063354
Epoch 24: Loss: 0.33418604731559753
Epoch 25: Loss: 0.2600495