<a href="https://colab.research.google.com/github/Rahad31/LLM-KL-FedDis/blob/main/LLM_KLFeddis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
from scipy.stats import truncnorm
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import Subset, DataLoader
from torchvision.datasets import CIFAR10
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
from transformers import pipeline
import random

# ------------------ LLM Setup ------------------
llm = pipeline("text-generation", model="gpt2")

def get_llm_suggestion(distribution_info: dict) -> str:
    mean = np.round(distribution_info["normal"]["mean"][:5], 2)
    std = np.round(distribution_info["normal"]["std"][:5], 2)
    prompt = f"Client observed mean {mean.tolist()} and std {std.tolist()}. Accuracy dropped below 70%. Suggest action:"
    suggestion = llm(prompt, max_length=50, num_return_sequences=1)
    suggestion_text = suggestion[0]['generated_text']
    print(f"\nüì£ LLM Suggestion:\n{suggestion_text}\n")
    return suggestion_text

# ------------------ VAE ------------------
class VAE(nn.Module):
    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1), nn.ReLU(),
            nn.Flatten()
        )
        self.fc_mu = nn.Linear(64 * 8 * 8, latent_dim)
        self.fc_logvar = nn.Linear(64 * 8 * 8, latent_dim)
        self.decoder_input = nn.Linear(latent_dim, 64 * 8 * 8)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (64, 8, 8)),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1), nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = self.decoder_input(z)
        return self.decoder(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# ------------------ CNN Classifier ------------------
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1, 1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, 1, 1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 256), nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        return self.net(x)

# ------------------ Helper Functions ------------------
def vae_train(vae, dataloader, epochs=1):
    vae.train()
    optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    for _ in range(epochs):
        for x, _ in dataloader:
            x = x.cuda()
            x_recon, mu, logvar = vae(x)
            recon_loss = criterion(x_recon, x)
            kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.shape[0]
            loss = recon_loss + 0.001 * kl_div
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

def train(model, trainloader, valloader, epochs=1):
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x, y in trainloader:
            x, y = x.cuda(), y.cuda()
            logits = model(x)
            loss = criterion(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

def get_distribution_info(vae):
    with torch.no_grad():
        z = torc

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Device set to use cpu


In [None]:
# VAE + LLM-Coordinated Federated Learning (CIFAR-10)

import numpy as np
from scipy.stats import truncnorm
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import Subset, DataLoader
from torchvision.datasets import CIFAR10
from typing import Dict, Tuple
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import random
from transformers import pipeline

# Initialize LLM (GPT-2)
llm = pipeline("text-generation", model="gpt2")

class VAE(nn.Module):
    def __init__(self, x_dim, h_dim, z_dim):
        super(VAE, self).__init__()
        self.z_dim = z_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * z_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        mu, logvar = h[:, :self.z_dim], h[:, self.z_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        return self.fc3(x)

transform = transforms.Compose([
    transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
full_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = CIFAR10(root="./data", train=False, download=True, transform=transform)
train_set, val_set = torch.utils.data.random_split(full_dataset, [int(0.8 * len(full_dataset)), len(full_dataset) - int(0.8 * len(full_dataset))])
trainloader = DataLoader(train_set, batch_size=128, shuffle=True)
valloader = DataLoader(val_set, batch_size=128, shuffle=False)
testloader = DataLoader(test_set, batch_size=128, shuffle=False)

def vae_loss(recon_x, x, mu, logvar):
    BCE = nn.functional.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def vae_train(vae, loader, epochs):
    opt = torch.optim.Adam(vae.parameters(), lr=1e-3)
    for epoch in range(epochs):
        vae.train()
        for x, _ in loader:
            opt.zero_grad()
            recon_x, mu, logvar = vae(x)
            loss = vae_loss(recon_x, x, mu, logvar)
            loss.backward()
            opt.step()

def train(net, trainloader, valloader, epochs):
    opt = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    crit = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        net.train()
        for x, y in trainloader:
            opt.zero_grad()
            out = net(x)
            loss = crit(out, y)
            loss.backward()
            opt.step()

def evaluate(net, loader):
    net.eval()
    total, correct = 0, 0
    for x, y in loader:
        out = net(x)
        _, pred = out.max(1)
        total += y.size(0)
        correct += pred.eq(y).sum().item()
    print(f"Test Accuracy: {100 * correct / total:.2f}%")

def get_distribution_info(vae):
    return {
        "normal": {
            "mean": vae.encoder[-1].bias.data.cpu().numpy(),
            "std": torch.exp(0.5 * vae.encoder[-1].weight.data).cpu().numpy()
        }
    }

def send_distribution_info(info):
    mean = np.round(info['normal']['mean'][:5], 2)
    std = np.round(info['normal']['std'][:5], 2)
    summary = f"Mean: {mean}, Std: {std}"
    output = llm(f"Suggest strategy for client with: {summary}", max_length=40)[0]['generated_text']
    print("\n[LLM Feedback]", output)
    return output

def receive_distribution_info():
    return {"normal": {"mean": np.zeros(20), "std": np.ones(20)}}

def generate_augmented_data(vae, dist):
    return torch.randn(64, vae.z_dim) * dist['std'] + dist['mean']

def initialize_clients(dataset, transform, batch_size, n):
    size = len(dataset) // n
    return {f"client_{i}": DataLoader(Subset(dataset, range(i * size, (i + 1) * size)), batch_size=batch_size, shuffle=True) for i in range(n)}

def federated_train(net, vae, trainloaders, valloader):
    for client_id, loader in trainloaders.items():
        print(f"\nClient: {client_id}")
        vae_train(vae, loader, epochs=1)
        info = get_distribution_info(vae)
        feedback = send_distribution_info(info)
        if "augment" in feedback:
            print("[Client Action] Applying augmentation...")
        train(net, loader, valloader, epochs=1)

def global_server():
    net = Net()
    vae = VAE(3*32*32, 400, 20)
    clients = initialize_clients(train_set, transform, 128, 5)
    federated_train(net, vae, clients, valloader)
    evaluate(net, testloader)

if __name__ == "__main__":
    global_server()


Device set to use cpu
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 170M/170M [00:02<00:00, 78.0MB/s]



Client: client_0


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Both `max_new_tokens` (=256) and `max_length`(=40) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)



[LLM Feedback] Suggest strategy for client with: Mean: [ 0.01  0.06  0.05 -0.04  0.03], Std: [[0.99 0.97 1.01 ... 0.98 0.97 1.03]
 [1.01 1.01 1.02 ... 1.   1.01 0.97]
 [0.97 1.01 1.01 ... 0.98 0.98 1.  ]
 [1.02 1.   1.   ... 1.02 1.01 0.99]
 [1.02 1.01 0.99 ... 1.   1.   0.98]].00 -0.01 0.04 0.03], Std: [[0.99 0.98 1.00 -1.00 1.00 0.03], Std: [[0.99 0.99 1.00 -1.00 1.00 0.03], Std: [[0.99 1.01 1.00 -1.00 1.00 0.03], Std: [[0.99 1.02 1.00 -1.00 1.00 0.03], Std: [[0.99 1.03 1.00 -1.00 1.00 0.03], Std: [[0.99 1.04 1.00 -1.00 1.00 0.03], Std: [[0.99 1.05 1.00 -1.00 1.00 0.03], Std: [[0.99 1.06 1.00 -1.00 1.00 0.03], Std: [[0.99 1.07 1.00 -1.00 1.00 0.03], Std: [[0.99 1.08 1.00 -1.00 1.00 0.03], Std:

Client: client_1


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Both `max_new_tokens` (=256) and `max_length`(=40) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)



[LLM Feedback] Suggest strategy for client with: Mean: [ 0.02  0.05  0.07 -0.06  0.03], Std: [[1.   0.97 1.01 ... 0.98 0.97 1.03]
 [1.01 1.01 1.02 ... 1.   1.01 0.97]
 [0.97 1.01 1.01 ... 0.98 0.99 1.  ]
 [1.01 1.   1.   ... 1.02 1.01 0.99]
 [1.02 1.01 0.99 ... 1.   1.   0.98]] -0.21  0.20 -0.03 -0.03 -0.03], Std: [0. ¬† 0.15 -0.03 -0.02 +0.02 -0.02], Std: [0. ¬† 0.07 -0.06 -0.01 -0.02 -0.02], Std: [0. ¬† 0.01 -0.00 -0.03 -0.02 -0.03], Std: [0. ¬† 0.03 -0.02 -0.01 +0.02 +0.02], Std: [0. ¬† 0.02 +0.06 +0.02 +0.02] Mean: [ 0.02  0.05  0.07 -0.06  0.03], Std: [[1.  0.97 -0.21  0.20 -0.03 -0.03 -0.03], Std: [0. ¬† 0.15 -0.03 -0.02 +0.02 -0.02], Std: [0. ¬† 0.07 -0.06 -0

Client: client_2


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Both `max_new_tokens` (=256) and `max_length`(=40) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)



[LLM Feedback] Suggest strategy for client with: Mean: [ 0.01  0.05  0.06 -0.06  0.04], Std: [[1.   0.97 1.01 ... 0.98 0.97 1.03]
 [1.01 1.01 1.02 ... 1.   1.01 0.97]
 [0.97 1.01 1.01 ... 0.98 0.99 1.  ]
 [1.01 1.   1.   ... 1.02 1.01 0.99]
 [1.02 1.01 0.99 ... 1.   1.   0.98]] -0.04 -0.06 -0.06 -0.03], Std: [[1.  0.98 -0.05 -0.06 -0.06 -0.03], Std: [[1.  0.99 -0.06 0.06 -0.06 -0.03], Std: [[1.  0.100 -0.05 -0.06 -0.06 -0.03]]

Annotation

The following example shows how to specify the interval of an input interval so that it exceeds the interval specified by the function.

Example 1: Output interval in minutes

var interval = 0;

The input interval must be a number that must be greater than or equal to:

0.01 = 0.01 * 12.22 * 24;

Example 2: Output interval in hours

var interval = 0;

The input interval must be a number that must be greater than or equal to:

0.01 = 0.01 * 10.25 * 15;

Example 3: Output interval in days

var interval = 0;

The input interval must be a number that mu

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Both `max_new_tokens` (=256) and `max_length`(=40) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)



[LLM Feedback] Suggest strategy for client with: Mean: [ 0.01  0.04  0.06 -0.05  0.04], Std: [[1.   0.97 1.01 ... 0.98 0.98 1.03]
 [1.01 1.01 1.02 ... 1.   1.02 0.97]
 [0.97 1.01 1.01 ... 0.98 0.99 1.  ]
 [1.01 1.   1.   ... 1.02 1.01 0.99]
 [1.02 1.01 0.99 ... 1.   1.   0.98]] -0.03 -0.08 0.03 -0.05], Std: [[1.  0.97 -0.03 -0.08 0.03 -0.05], Std: [[1.  0.97 -0.03 -0.08 0.03 -0.05], Std: [[1.  0.97 -0.03 -0.08 0.03 -0.05], Std: [[1.  0.97 -0.03 -0.08 0.03 -0.05], Std: [[1.  0.97 -0.03 -0.08 0.03 -0.05], Std: [[1.  0.97 -0.03 -0.08 0.03 -0.05], Std: [[1.  0.97 -0.03 -0.08 0.03 -0.05], Std: [[1.  0.97 -0.03 -0.08 0.03 -0.05], Std: [[1.  0.97 -0.03 -0.08 0.03 -0.05], Std: [[1.

Client: client_4


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Both `max_new_tokens` (=256) and `max_length`(=40) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)



[LLM Feedback] Suggest strategy for client with: Mean: [ 0.01  0.05  0.06 -0.05  0.03], Std: [[1.   0.97 1.01 ... 0.98 0.99 1.03]
 [1.01 1.01 1.02 ... 1.   1.03 0.97]
 [0.97 1.01 1.01 ... 0.98 0.99 1.  ]
 [1.01 1.   1.   ... 1.02 1.01 0.99]
 [1.02 1.01 0.99 ... 1.   0.99 0.98]] -0.15 -0.02 0.038 -0.08]] Mean: [[1.  0.85 -0.25 0.02 0.038 -0.08]] Mean: [[1.  0.81 -0.25 0.02 0.038 -0.08]] Mean: [[1.  0.81 -0.25 0.02 0.038 -0.08]] Mean: [[1.  0.86 -0.25 0.02 0.038 -0.08]] Mean: [[1.  0.85 -0.25 0.02 0.038 -0.08]] Mean: [[1.  0.86 -0.25 0.02 0.038 -0.08]] Mean: [[1.  0.86 -0.25 0.02 0.038 -0.08]] Mean: [[1.  0.86 -0.25 0.02 0.038 -0.08]] Mean: [[1.  0.86 -0.25 0.02 0.038 -0.08]] Mean: [[1.  0.86 -0.25 0
Test Accuracy: 10.07%


In [None]:
import numpy as np
from scipy.stats import truncnorm
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim
from torch.utils.data import Subset, DataLoader
from torchvision.datasets import CIFAR10
from transformers import pipeline
from typing import Dict
import random

# Initialize LLM (GPT-2 for local use)
llm = pipeline("text-generation", model="gpt2")

# Define VAE model
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32*32*3, 512),
            nn.ReLU(),
            nn.Linear(512, 128)
        )
        self.fc_mu = nn.Linear(128, 20)
        self.fc_logvar = nn.Linear(128, 20)
        self.decoder = nn.Sequential(
            nn.Linear(20, 512),
            nn.ReLU(),
            nn.Linear(512, 32*32*3),
            nn.Sigmoid(),
            nn.Unflatten(1, (3, 32, 32))
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Simple CNN for classification
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(8*8*32, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        return self.model(x)

def get_distribution_info(vae: VAE, dataloader: DataLoader):
    vae.eval()
    latents = []
    with torch.no_grad():
        for x, _ in dataloader:
            x = x.to(next(vae.parameters()).device)
            mu, logvar = vae.encode(x)
            z = vae.reparameterize(mu, logvar)
            latents.append(z.cpu().numpy())
    latents = np.concatenate(latents, axis=0)
    return {
        "normal": {
            "mean": np.mean(latents, axis=0),
            "std": np.std(latents, axis=0)
        }
    }

def send_distribution_info(distribution_info: Dict) -> str:
    mean = np.round(distribution_info["normal"]["mean"][:5], 2)
    std = np.round(distribution_info["normal"]["std"][:5], 2)
    summary = f"Client observed mean {mean.tolist()} and std {std.tolist()}. Accuracy dropped below 70%."
    suggestion = llm(f"Suggest action: {summary}", max_length=50, num_return_sequences=1)
    suggestion_text = suggestion[0]['generated_text']
    print(f"\nüì£ LLM Suggestion:\n{suggestion_text}\n")
    return suggestion_text

def train_classifier(model: nn.Module, trainloader: DataLoader, valloader: DataLoader, epochs: int):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        val_correct = 0
        val_total = 0
        model.eval()
        with torch.no_grad():
            for val_images, val_labels in valloader:
                val_images, val_labels = val_images.to(device), val_labels.to(device)
                outputs = model(val_images)
                _, predicted = torch.max(outputs, 1)
                val_total += val_labels.size(0)
                val_correct += (predicted == val_labels).sum().item()

        train_loss = running_loss / len(trainloader)
        val_acc = 100 * val_correct / val_total
        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {train_loss:.3f}, Validation Accuracy: {val_acc:.2f}%")

# Dataset & Simulated Clients
transform = transforms.Compose([
    transforms.ToTensor()
])
dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
valset = CIFAR10(root='./data', train=False, download=True, transform=transform)
val_loader = DataLoader(valset, batch_size=100, shuffle=False)

client_indices = np.array_split(np.arange(len(dataset)), 2)
client_loaders = {
    f"client_{i}": DataLoader(Subset(dataset, idx), batch_size=64, shuffle=True)
    for i, idx in enumerate(client_indices)
}

# Training Loop
for client_id, loader in client_loaders.items():
    print(f"\nüöÄ Training for {client_id}")
    vae = VAE()
    classifier = Classifier()
    distribution_info = get_distribution_info(vae, loader)
    send_distribution_info(distribution_info)
    train_classifier(classifier, loader, val_loader, epochs=10)


Device set to use cpu



üöÄ Training for client_0


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Both `max_new_tokens` (=256) and `max_length`(=50) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)



üì£ LLM Suggestion:
Suggest action: Client observed mean [-0.05000000074505806, -0.03999999910593033, 0.1899999976158142, -0.019999999552965164, 0.03999999910593033] and std [1.1200000047683716, 0.9599999785423279, 0.9700000286102295, 1.059999942779541, 0.8999999761581421]. Accuracy dropped below 70%.99959237593, 0.019999999552965164, 0.029999999552965164, 0.0299999999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164, 0.019999999552965164

Epoch [1/10], Training Loss: 1.748, Validation Accuracy: 47.31%
Epoch [2/10], Training Loss: 1.391, Validation Accuracy: 52.90%
E

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Both `max_new_tokens` (=256) and `max_length`(=50) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)



üì£ LLM Suggestion:
Suggest action: Client observed mean [-0.07000000029802322, -0.11999999731779099, -0.019999999552965164, -0.019999999552965164, -0.11999999731779099] and std [1.0199999809265137, 1.0099999904632568, 1.0099999904632568, 1.0199999809265137, 0.9700000286102295]. Accuracy dropped below 70%.0.0100000001131244493601, -0.019999999552965164, -0.0100000001131244493601, -0.0100000001131244493601, -0.0100000001131244493601)

Note: The above table assumes that all of the keys are in the order specified.

You can also use the KeyPacket method.

See the KeyPacket documentation for more details.

If you still have questions about the results of this query, be sure to read the KeyPacket documentation.

Get the data

You can send an HTTP request using the following method:

GET /data/query/query_result?type=text

If you don't specify a data type, the output is the same as the previous query.

You can get the results by calling the KeyPacket method.

If you need to send data to the

In [None]:
import numpy as np
from scipy.stats import truncnorm
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score
from transformers import pipeline
import random

# Define the VAE model
class VAE(nn.Module):
    def __init__(self, latent_dim=128):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1), nn.ReLU(),
            nn.Flatten()
        )
        self.fc_mu = nn.Linear(64 * 8 * 8, latent_dim)
        self.fc_logvar = nn.Linear(64 * 8 * 8, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, 64 * 8 * 8)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (64, 8, 8)),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1), nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = self.fc_decode(z)
        return self.decoder(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# CNN Classifier
class CNNClassifier(nn.Module):
    def __init__(self):
        super(CNNClassifier, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(), nn.Linear(64 * 8 * 8, 256), nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        return self.net(x)

# Federated training
transform = transforms.Compose([transforms.ToTensor()])
trainset = CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = CIFAR10(root='./data', train=False, download=True, transform=transform)

num_clients = 5
client_indices = np.array_split(np.arange(len(trainset)), num_clients)
trainloaders = {
    f'client_{i}': DataLoader(Subset(trainset, client_indices[i]), batch_size=64, shuffle=True)
    for i in range(num_clients)
}
valloader = DataLoader(testset, batch_size=64, shuffle=False)

vae = VAE().cuda()
net = CNNClassifier().cuda()
optimizer_vae = torch.optim.Adam(vae.parameters(), lr=1e-3)
optimizer_net = torch.optim.Adam(net.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
llm = pipeline("text-generation", model="gpt2")

def vae_train(vae, dataloader, epochs=1):
    vae.train()
    for epoch in range(epochs):
        for x, _ in dataloader:
            x = x.cuda()
            x_recon, mu, logvar = vae(x)
            recon_loss = nn.functional.mse_loss(x_recon, x)
            kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            loss = recon_loss + kl_loss
            optimizer_vae.zero_grad()
            loss.backward()
            optimizer_vae.step()

def train(net, dataloader, epochs):
    net.train()
    for epoch in range(epochs):
        total_loss = 0
        for x, y in dataloader:
            x, y = x.cuda(), y.cuda()
            output = net(x)
            loss = criterion(output, y)
            optimizer_net.zero_grad()
            loss.backward()
            optimizer_net.step()
            total_loss += loss.item()
        acc = evaluate(net, valloader)
        print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {total_loss / len(dataloader):.3f}, Validation Accuracy: {acc:.2f}%")

def evaluate(net, dataloader):
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.cuda(), y.cuda()
            output = net(x)
            pred = output.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return 100. * correct / total

def final_metrics(net, dataloader):
    net.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for x, y in dataloader:
            x = x.cuda()
            output = net(x)
            pred = output.argmax(dim=1).cpu()
            y_true.extend(y.numpy())
            y_pred.extend(pred.numpy())
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average='macro')
    rec = recall_score(y_true, y_pred, average='macro')
    f1 = f1_score(y_true, y_pred, average='macro')
    print("\nüìä Final Metrics:")
    print(f"Accuracy: {acc * 100:.2f}%")
    print(f"Precision: {prec:.4f}, Recall: {rec:.4f}, F1-Score: {f1:.4f}")

def federated_train():
    for client_id, loader in trainloaders.items():
        print(f"\nüß† Training {client_id}")
        vae_train(vae, loader, epochs=1)
    for epoch in range(500):
        print(f"\nüåç Global Epoch {epoch+1}/500")
        for client_id, loader in trainloaders.items():
            train(net, loader, epochs=1)
    final_metrics(net, valloader)

federated_train()



RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [None]:
# kl_feddis_vit_cifar10/main.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, ViTModel
from torch.utils.data import DataLoader, Subset
import numpy as np
import copy
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import random

# --- 1. VAE Module ---
class VAE(nn.Module):
    def __init__(self, input_dim=768, latent_dim=64):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(input_dim, 512)
        self.fc21 = nn.Linear(512, latent_dim)  # mu
        self.fc22 = nn.Linear(512, latent_dim)  # logvar
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim)
        )

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc21(h), self.fc22(h)

    def reparametrize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decoder(z), mu, logvar

# --- 2. Data Partitioning ---
def non_iid_partition(dataset, num_clients):
    label_indices = {i: [] for i in range(10)}
    for idx, (_, label) in enumerate(dataset):
        label_indices[label].append(idx)

    client_data = [[] for _ in range(num_clients)]
    for label, indices in label_indices.items():
        random.shuffle(indices)
        splits = np.array_split(indices, num_clients)
        for client_id, split in enumerate(splits):
            client_data[client_id].extend(split)
    return [Subset(dataset, idxs) for idxs in client_data]

# --- 3. Training on Client ---
def train_client(model, vae, dataloader, optimizer, device):
    model.train()
    vae.train()
    total_kl = 0.0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        # Get embeddings from the ViT model before the classifier head
        outputs = model(images, output_hidden_states=True) # Explicitly request hidden states
        logits = outputs.logits
        embeddings = outputs.hidden_states[-1][:, 0, :] # Get the [CLS] token embedding from the last layer

        recon, mu, logvar = vae(embeddings)
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / len(images)

        loss = nn.CrossEntropyLoss()(logits, labels) + kl * 1e-3
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_kl += kl.item()

    return model.state_dict(), mu.detach(), logvar.detach(), total_kl

# --- 4. Aggregation ---
def kl_weighted_aggregate(client_models, mus, logvars):
    global_model = copy.deepcopy(client_models[0])
    global_mu = torch.stack(mus).mean(dim=0)
    global_logvar = torch.stack(logvars).mean(dim=0)

    kl_weights = []
    for mu, logvar in zip(mus, logvars):
        # Calculate KL divergence between client distribution and global distribution
        kl = torch.sum(0.5 * (global_logvar - logvar +
                 (torch.exp(logvar) + (mu - global_mu).pow(2)) / torch.exp(global_logvar) - 1))
        # Use a small epsilon to prevent division by zero and encourage exploration
        weight = 1 / (1 + kl.item() + 1e-9)
        kl_weights.append(weight)

    # Normalize weights
    kl_weights = [w / sum(kl_weights) for w in kl_weights]

    for key in global_model:
        # Aggregate based on KL weights, excluding the classifier head
        if 'classifier' not in key:
             global_model[key] = sum(w * client_model[key] for w, client_model in zip(kl_weights, client_models))

    return global_model

# --- 5. Evaluation ---
def evaluate(model, test_loader, device):
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            preds = outputs.argmax(dim=1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
    acc = accuracy_score(y_true, y_pred)
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='macro', zero_division=0)
    return acc, p, r, f

# --- 6. Main Federated Training Loop ---
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Add normalization for ViT
    ])

    cifar_train = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    cifar_test = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
    test_loader = DataLoader(cifar_test, batch_size=128, shuffle=False)

    num_clients = 5
    clients_data = non_iid_partition(cifar_train, num_clients)

    # Load the pretrained ViT model without the classification head
    global_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=10, ignore_mismatched_sizes=True).to(device)
    global_vae = VAE().to(device)

    rounds = 5
    for rnd in range(rounds):
        client_models, mus, logvars = [], [], []
        print(f"\n--- Federated Learning Round {rnd+1}/{rounds} ---")

        for client_id in range(num_clients):
            print(f"Training Client {client_id + 1}")
            model = copy.deepcopy(global_model).to(device)
            vae = copy.deepcopy(global_vae).to(device)
            optimizer = torch.optim.Adam(list(model.parameters()) + list(vae.parameters()), lr=1e-4)

            dataloader = DataLoader(clients_data[client_id], batch_size=64, shuffle=True)
            state_dict, mu, logvar, avg_kl = train_client(model, vae, dataloader, optimizer, device)

            client_models.append(state_dict)
            mus.append(mu)
            logvars.append(logvar)
            print(f"Client {client_id + 1} finished training. Average KL: {avg_kl:.4f}")


        print("Aggregating client models...")
        global_model.load_state_dict(kl_weighted_aggregate(client_models, mus, logvars))
        # After aggregation, re-initialize the VAE for the next round with potentially updated global distribution
        global_vae = VAE().to(device) # Re-initialize VAE for the next round


        acc, p, r, f = evaluate(global_model, test_loader, device)
        print(f"\nRound {rnd+1} Global Model Evaluation:")
        print(f"Accuracy: {acc*100:.2f}%, Precision: {p*100:.2f}%, Recall: {r*100:.2f}%, F1: {f*100:.2f}%")

if __name__ == "__main__":
    main()

Using device: cpu
