In [2]:
pip install gradio






[notice] A new release of pip is available: 23.3.1 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [3]:
import torch.nn as nn
class CNN(nn.Module):
    """Two-block convolutional classifier for 3-channel 256x256 images.

    Both convolutions use kernel 3 / stride 1 / padding 1 and therefore keep
    the spatial size; each is followed by a 2x2 max-pool that halves it.
    A 256x256 input thus reaches the classifier as 32 feature maps of 64x64,
    which is where the ``32 * 64 * 64`` input width of ``fc1`` comes from.

    Args:
        num_classes: Number of output logits produced by the final layer.
    """

    def __init__(self, num_classes):
        super(CNN, self).__init__()
        # Define layers
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # One stateless activation module reused everywhere; it holds no
        # parameters, so previously saved state_dicts still load unchanged.
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(32 * 64 * 64, 128)  # 32 channels x 64 x 64 after two pools of a 256x256 input
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)``.

        Args:
            x: Image batch of shape ``(batch, 3, 256, 256)``.

        Raises:
            ValueError: If the flattened per-sample feature count does not
                match what ``fc1`` expects (i.e. the input is not 256x256).
        """
        x = self.pool(self.relu(self.conv1(x)))  # Conv Layer 1
        x = self.pool(self.relu(self.conv2(x)))  # Conv Layer 2

        # Flatten per sample. A hard-coded view(-1, 32 * 64 * 64) would
        # silently merge samples together when the input size is wrong
        # (e.g. four 128x128 images collapse into a single row).
        x = x.flatten(start_dim=1)
        if x.size(1) != self.fc1.in_features:
            raise ValueError(
                f"Expected {self.fc1.in_features} features per sample after the conv "
                f"layers (256x256 input), got {x.size(1)}."
            )

        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [7]:
import torch
import copy

# Assuming your CNN class definition is available here
# Define a function to load models from saved files
def load_client_models(num_clients, num_classes=10, model_dir="client models", map_location="cpu"):
    """Load the saved per-client CNN weights.

    Checkpoints are read from ``{model_dir}/client_{k}.pth`` for
    ``k = 1 .. num_clients``.

    Args:
        num_clients: Number of client checkpoints to load.
        num_classes: Output size used to build each ``CNN``; must match the
            checkpoints being loaded.
        model_dir: Directory containing the ``client_{k}.pth`` files.
        map_location: Passed to ``torch.load``. The default ``"cpu"`` lets
            checkpoints that were saved from a GPU load on a CPU-only machine.

    Returns:
        A list of ``CNN`` instances in client order (client 1 first).
    """
    local_models = []
    for client_id in range(1, num_clients + 1):
        model = CNN(num_classes=num_classes)
        checkpoint_path = f"{model_dir}/client_{client_id}.pth"
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True where supported).
        state_dict = torch.load(checkpoint_path, map_location=map_location)
        model.load_state_dict(state_dict)
        local_models.append(model)
    return local_models

# Load the saved model of every client (5 clients; checkpoints are assumed
# to come from the last training round, the 5th)
NUM_CLIENTS = 5
local_models = load_client_models(num_clients=NUM_CLIENTS)

In [4]:
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image, UnidentifiedImageError
import os

# Preprocessing applied to every uploaded image, in order:
# 3-channel grayscale -> resize to 256x256 -> tensor -> per-channel normalization
_preprocess_steps = [
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

# Class names in index order, plus the reverse lookup (name -> class index)
label_names = [
    'cat', 'dog', 'bird', 'fish', 'car',
    'aircraft', 'flower', 'truck', 'parachute', 'mushroom',
]
label_to_index = dict(zip(label_names, range(len(label_names))))

# Evaluation function
def evaluate_custom_batch(images, labels):
    """Score every model in the global ``local_models`` on one labelled batch.

    Args:
        images: Sequence of images accepted by the module-level ``transform``.
        labels: Integer class indices into ``label_names``, one per image.

    Returns:
        A dict with two keys:
        ``"Client Accuracies"`` maps ``"Client k Accuracy"`` to a percentage
        string, and ``"Processed Information"`` is a list with one entry per
        client holding the per-image label and prediction.
        For an empty batch every client reports ``"0.00%"`` and no per-image
        entries.

    Raises:
        ValueError: If ``images`` and ``labels`` differ in length.
    """
    if len(images) != len(labels):
        raise ValueError(
            f"Got {len(images)} images but {len(labels)} labels; they must match."
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # torch.stack rejects an empty list, so only build the batch when there
    # is something to evaluate.
    has_images = len(images) > 0
    if has_images:
        preprocessed_images = torch.stack([transform(img) for img in images]).to(device)

    client_accuracies = {}
    processed_info = []

    for idx, client_model in enumerate(local_models):
        if has_images:
            client_model.eval()
            client_model.to(device)
            with torch.no_grad():
                outputs = client_model(preprocessed_images)
                _, predicted = torch.max(outputs, 1)

            # Plain Python ints: usable for comparison, indexing and JSON output
            predictions = predicted.cpu().tolist()
            correct = sum(1 for pred, target in zip(predictions, labels) if pred == target)
            accuracy = correct / len(labels)
        else:
            predictions = []
            accuracy = 0

        client_accuracies[f"Client {idx + 1} Accuracy"] = f"{accuracy * 100:.2f}%"

        # Per-image details for this client
        client_info = [
            {
                "filename": f"Image {i + 1}",  # Placeholder for filenames if needed
                "label": int(label),
                "prediction": int(prediction),
                "label_name": label_names[label],
                "prediction_name": label_names[prediction],
            }
            for i, (label, prediction) in enumerate(zip(labels, predictions))
        ]

        processed_info.append({
            "client": f"Client {idx + 1}",
            "info": client_info
        })

    return {
        "Client Accuracies": client_accuracies,
        "Processed Information": processed_info
    }

# Gradio Interface function
def gradio_interface(files):
    """Gradio callback: load the uploaded images and evaluate all client models.

    The expected class of each file is taken from its basename: the text
    before the first underscore must be a key of ``label_to_index``
    (e.g. ``cat_01.png``). Files with an unrecognised label or unreadable
    image data are reported on stdout and skipped; the remaining files are
    still evaluated.

    Args:
        files: Uploaded file objects exposing a ``.name`` path, or ``None``
            when nothing was uploaded.

    Returns:
        The result of ``evaluate_custom_batch`` for the valid files, or
        ``{"error": ...}`` when no valid image remains.
    """
    images = []
    labels = []

    # Treat a missing upload (None) like an empty list.
    for file in files or []:
        try:
            # Validate the label first so a bad filename never leaves an
            # image without a matching label.
            label_str = os.path.basename(file.name).split('_')[0]
            if label_str not in label_to_index:
                raise ValueError(f"Label '{label_str}' not recognized.")

            # convert() returns an independent, fully loaded image, so the
            # underlying file handle can be closed right away.
            with Image.open(file.name) as raw_img:
                img = raw_img.convert("RGB")
        except (UnidentifiedImageError, OSError, ValueError) as e:
            print(f"Error processing file {file.name}: {e}")
            continue

        # Append both together: the two lists always stay aligned.
        images.append(img)
        labels.append(label_to_index[label_str])

    if not images:
        return {"error": "No valid images were uploaded. Please upload image files only."}

    return evaluate_custom_batch(images, labels)

# Build the upload form and start the Gradio UI
upload_input = gr.File(label="Upload Images", type="file", file_count="multiple")
demo = gr.Interface(
    fn=gradio_interface,
    inputs=upload_input,
    outputs="json",
    title="Federated Learning Client Evaluation",
    description="Upload a batch of images to evaluate client model accuracy based on the labels extracted from filenames.",
    allow_flagging="never",
    live=False,
)
demo.launch()


Running on local URL:  http://127.0.0.1:7860

Thanks for being a Gradio user! If you have questions or feedback, please join our Discord server and chat with us: https://discord.gg/feTf9x3ZSB

To create a public link, set `share=True` in `launch()`.




IMPORTANT: You are using gradio version 3.45.2, however version 4.44.1 is available, please upgrade.
--------
