In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, TensorDataset
import numpy as np
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from typing import Sequence, Tuple
import pandas as pd
from collections import OrderedDict, defaultdict
import torch.nn.functional as F

# Module-level log: NAM.print_model_equation appends every equation string it prints here.
model_equations = list()

def fed_model(testimages):
    """Train a federated-averaged MLP on Iris and return its logits for ``testimages``.

    Every call repeats the whole pipeline from scratch:

    1. Load Iris, standardize all features, and keep the 80% training split of
       ``train_test_split(test_size=0.2, random_state=42)``.
    2. Shuffle that split with the global NumPy RNG and divide it into 3 client shards.
    3. Train a freshly initialized ``SimpleNN`` per client (SGD, lr=0.01,
       momentum=0.9, 20 epochs, batch size 16) and save its weights to
       ``client_<k>_model.pth`` in the current working directory.
    4. Average the client weights entry-wise with equal weights (FedAvg) and load
       them into a new ``SimpleNN``.

    Args:
        testimages: array-like or tensor whose last dimension is 4 (the model's
            input width). Presumably standardized Iris features -- TODO confirm
            callers apply matching scaling.

    Returns:
        Tensor of raw logits with last dimension 3. It is computed with autograd
        enabled, so it stays attached to the averaged model's graph.
    """
    # Load and standardize Iris; only the training portion of the split is used.
    iris = datasets.load_iris()
    X = StandardScaler().fit_transform(iris.data)
    y = iris.target
    X_train, _, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=42)

    train_dataset = TensorDataset(
        torch.tensor(X_train, dtype=torch.float32),
        torch.tensor(y_train, dtype=torch.long),
    )

    n_clients = 3
    batch_size = 16
    epochs = 20

    # Randomly partition the training rows into n_clients disjoint shards.
    indices = np.arange(len(train_dataset))
    np.random.shuffle(indices)
    client_loaders = [
        DataLoader(Subset(train_dataset, shard), batch_size=batch_size, shuffle=True)
        for shard in np.array_split(indices, n_clients)
    ]

    class SimpleNN(nn.Module):
        """4 -> 50 -> 20 -> 3 MLP with ReLU activations (4 Iris features, 3 classes)."""

        def __init__(self):
            super(SimpleNN, self).__init__()
            self.fc1 = nn.Linear(4, 50)
            self.fc2 = nn.Linear(50, 20)
            self.fc3 = nn.Linear(20, 3)
            self.relu = nn.ReLU()

        def forward(self, x):
            x = self.relu(self.fc1(x))
            x = self.relu(self.fc2(x))
            return self.fc3(x)

    # Local training: one independently initialized model per client.
    criterion = nn.CrossEntropyLoss()
    client_states = []
    for client_id, loader in enumerate(client_loaders, start=1):
        model = SimpleNN()
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
        model.train()
        for _ in range(epochs):
            for inputs, labels in loader:
                optimizer.zero_grad()
                loss = criterion(model(inputs), labels)
                loss.backward()
                optimizer.step()
        state = model.state_dict()
        torch.save(state, f'client_{client_id}_model.pth')
        client_states.append(state)

    # Federated averaging: equal-weight mean of every state_dict entry.
    averaged_state = OrderedDict(
        (key, torch.stack([state[key] for state in client_states]).mean(dim=0))
        for key in client_states[0]
    )

    global_model = SimpleNN()
    global_model.load_state_dict(averaged_state)
    # as_tensor accepts arrays and existing tensors alike; torch.tensor() on an
    # existing tensor emits a copy-construct UserWarning on every call.
    inputs = torch.as_tensor(testimages, dtype=torch.float32)
    return global_model(inputs)

# Hyper-parameter container shared by the NAM models and their training loops
class Config:
    """Plain bag of training hyper-parameters.

    Attributes:
        dropout: dropout probability applied to the concatenated feature outputs.
        learning_rate: optimizer step size.
        num_epochs: number of training epochs.
        batch_size: intended mini-batch size.
    """

    def __init__(self, dropout=0.5, learning_rate=0.001, num_epochs=50, batch_size=32):
        self.dropout, self.learning_rate, self.num_epochs, self.batch_size = (
            dropout,
            learning_rate,
            num_epochs,
            batch_size,
        )

# Minimal base module: stores the config object and a model name
class Model(nn.Module):
    """nn.Module base class that records ``config`` and ``name`` as attributes."""

    def __init__(self, config, name):
        super().__init__()
        self.config = config
        self.name = name

# Per-feature subnetwork used by NAM
class FeatureNN(nn.Module):
    """One Linear(input_shape, num_units) layer followed by ReLU.

    ``config``, ``name``, ``input_shape``, ``num_units`` and ``feature_num`` are
    stored as attributes; only ``fc`` takes part in the forward pass.
    """

    def __init__(self, config, name, input_shape, num_units, feature_num):
        super().__init__()
        self.config = config
        self.name = name
        self.input_shape = input_shape
        self.num_units = num_units
        self.feature_num = feature_num
        self.fc = nn.Linear(input_shape, num_units)

    def forward(self, x):
        return F.relu(self.fc(x))

# NAM model definition
class NAM(Model):
    """Neural-additive-style classifier with one small subnetwork per input feature.

    Column ``i`` of the input goes through its own ``FeatureNN``
    (Linear(1, num_units[i]) + ReLU). The per-feature outputs are concatenated,
    passed through dropout, and mapped by one linear layer to ``num_classes`` logits.

    Args:
        config: object exposing ``dropout`` (the dropout probability).
        name: model name, stored by ``Model``.
        num_inputs: number of input features (columns).
        num_units: hidden width per feature -- one int shared by all features, or a
            list with exactly ``num_inputs`` entries.
        num_classes: number of output logits (default 3, the Iris classes).

    Raises:
        ValueError: if ``num_units`` is a list of the wrong length.
        TypeError: if ``num_units`` is neither an int nor a list.
    """

    def __init__(self, config, name, *, num_inputs: int, num_units: int, num_classes: int = 3) -> None:
        super(NAM, self).__init__(config, name)
        self._num_inputs = num_inputs
        self.dropout = nn.Dropout(p=self.config.dropout)

        if isinstance(num_units, list):
            # Explicit check instead of assert, which is stripped under `python -O`.
            if len(num_units) != num_inputs:
                raise ValueError(
                    f"num_units has {len(num_units)} entries but num_inputs is {num_inputs}"
                )
            self._num_units = list(num_units)
        elif isinstance(num_units, int):
            self._num_units = [num_units] * num_inputs
        else:
            # Fail fast rather than later on a missing self._num_units attribute.
            raise TypeError(f"num_units must be an int or a list, got {type(num_units).__name__}")

        self.feature_nns = nn.ModuleList([
            FeatureNN(config=config, name=f'FeatureNN_{i}', input_shape=1, num_units=self._num_units[i], feature_num=i)
            for i in range(num_inputs)
        ])

        self.output_layer = nn.Linear(sum(self._num_units), num_classes)
        # Scalar parameter that forward() never uses; only print_model_equation reports it.
        self._bias = torch.nn.Parameter(data=torch.zeros(1))

    def calc_outputs(self, inputs: torch.Tensor) -> Sequence[torch.Tensor]:
        """Apply each feature subnetwork to its own column (inputs indexed as [batch, feature])."""
        return [fnn(inputs[:, i:i + 1]) for i, fnn in enumerate(self.feature_nns)]

    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(logits, dropped_features)``.

        ``dropped_features`` is the concatenation of all feature-subnetwork outputs
        after dropout; ``logits`` is ``output_layer`` applied to it.
        """
        individual_outputs = self.calc_outputs(inputs)
        conc_out = torch.cat(individual_outputs, dim=-1)
        dropout_out = self.dropout(conc_out)
        out = self.output_layer(dropout_out)
        return out, dropout_out

    def print_model_equation(self, feature_names):
        """Print a textual summary of every feature subnetwork's first layer.

        Each printed term is one affine map ``w * x + b`` of a FeatureNN's Linear
        layer (before ReLU). The ReLU, dropout and output layer are not included, so
        the printed "equation" summarizes the feature subnetworks rather than
        reproducing the model's exact function. The equation string is also
        appended to the module-level ``model_equations`` list.

        A feature's contribution score is the sum of absolute first-layer weights.

        Args:
            feature_names: names indexed by feature position.

        Returns:
            The name of the feature with the largest contribution score.
        """
        equation_terms = []
        feature_contributions = {}
        for i, fnn in enumerate(self.feature_nns):
            coefficients = fnn.fc.weight.data.flatten().tolist()
            intercepts = fnn.fc.bias.data.tolist()
            term = " + ".join([f"({coeff:.3f} * x_{feature_names[i]} + {intercept:.3f})" for coeff, intercept in zip(coefficients, intercepts)])
            equation_terms.append(term)
            feature_contributions[feature_names[i]] = sum(abs(c) for c in coefficients)
        equation = " + ".join(equation_terms) + f" + bias ({self._bias.item():.3f})"
        print(f"Model Equation: y = {equation}")
        model_equations.append(equation)

        # Rank features by total absolute first-layer weight, largest first.
        interpretability = sorted(feature_contributions.items(), key=lambda x: x[1], reverse=True)
        print("\nFeature Contributions:")
        for feature, contribution in interpretability:
            print(f"{feature}: {contribution:.3f}")

        return interpretability[0][0]  # Return the feature with the highest contribution

n_clients = 3
# Load Iris dataset
iris = datasets.load_iris()
X = iris.data
y = iris.target
feature_columns = iris.feature_names

# Standardize the features (fitted on all rows, before the client split)
scaler = StandardScaler()
X = scaler.fit_transform(X)

# Randomly partition the rows into n_clients disjoint shards, one per simulated client
indices = np.arange(len(X))
np.random.shuffle(indices)
split_indices = np.array_split(indices, n_clients)
client_data = [(X[shard], y[shard]) for shard in split_indices]

# Hyper-parameters and architecture shared by every client's NAM
config = Config(dropout=0.5, learning_rate=0.001, num_epochs=50, batch_size=32)
num_inputs = len(feature_columns)  # Number of features
num_units = 10  # Hidden units per feature subnetwork


def train(model, X_train, y_train, config):
    """Train ``model`` full-batch on (X_train, y_train) with Adam and cross-entropy.

    ``model`` must return ``(logits, features)`` as ``NAM.forward`` does. Runs
    ``config.num_epochs`` steps at ``config.learning_rate``; ``config.batch_size``
    is not used because every step sees the whole training set. Prints the loss
    every 10 epochs and returns the same model object.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    model.train()
    for epoch in range(config.num_epochs):
        optimizer.zero_grad()
        # The loss must come from the model being optimized, on its own training
        # split; otherwise its parameters get no gradient and test labels leak in.
        outputs, _ = model(X_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch + 1}/{config.num_epochs}], Loss: {loss.item():.4f}')
    return model


def evaluate(model, X_test, y_test):
    """Print and return the accuracy of ``model`` on (X_test, y_test)."""
    model.eval()
    with torch.no_grad():
        outputs, _ = model(X_test)
        predicted = outputs.argmax(dim=1)
        accuracy = (predicted == y_test).sum().item() / y_test.size(0)
    print(f'Accuracy: {accuracy * 100:.2f}%')
    return accuracy


for i in range(n_clients):
    X, y = client_data[i]

    # Split this client's shard into its own train/test parts
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    # Convert to PyTorch tensors
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long)
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
    y_test_tensor = torch.tensor(y_test, dtype=torch.long)

    # Fresh NAM per client
    nam_model = NAM(config=config, name='NAM_Model', num_inputs=num_inputs, num_units=num_units)

    trained_model = train(nam_model, X_train_tensor, y_train_tensor, config)
    evaluate(trained_model, X_test_tensor, y_test_tensor)

    # Print the model equation and get the most contributing feature
    most_contributing_feature = trained_model.print_model_equation(feature_columns)
    print(f"\nMost contributing feature for client's output {i}: {most_contributing_feature}")


  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)


Epoch [10/50], Loss: 0.9898


  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)


Epoch [20/50], Loss: 0.9417


  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)


Epoch [30/50], Loss: 1.0564


  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)


Epoch [40/50], Loss: 0.8475


  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)


Epoch [50/50], Loss: 1.0457
Accuracy: 26.67%
Model Equation: y = (-0.357 * x_sepal length (cm) + 0.971) + (-0.778 * x_sepal length (cm) + 0.007) + (0.519 * x_sepal length (cm) + 0.852) + (0.661 * x_sepal length (cm) + -0.670) + (-0.432 * x_sepal length (cm) + 0.580) + (-0.022 * x_sepal length (cm) + 0.341) + (0.456 * x_sepal length (cm) + 0.956) + (-0.868 * x_sepal length (cm) + -0.559) + (-0.107 * x_sepal length (cm) + -0.425) + (0.010 * x_sepal length (cm) + 0.657) + (-0.377 * x_sepal width (cm) + -0.399) + (-0.615 * x_sepal width (cm) + 0.452) + (0.985 * x_sepal width (cm) + -0.292) + (0.249 * x_sepal width (cm) + 0.726) + (-0.145 * x_sepal width (cm) + -0.657) + (-0.290 * x_sepal width (cm) + 0.167) + (-0.879 * x_sepal width (cm) + 0.642) + (-0.559 * x_sepal width (cm) + -0.269) + (0.643 * x_sepal width (cm) + 0.735) + (0.013 * x_sepal width (cm) + 0.303) + (0.861 * x_petal length (cm) + 0.502) + (0.466 * x_petal length (cm) + -0.854) + (0.697 * x_petal length (cm) + 0.162) + (0.43

  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)


Epoch [10/50], Loss: 0.9156


  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)


Epoch [20/50], Loss: 0.9606


  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)


Epoch [30/50], Loss: 0.9803


  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)


Epoch [40/50], Loss: 1.0143


  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)


Epoch [50/50], Loss: 0.9843
Accuracy: 26.67%
Model Equation: y = (0.379 * x_sepal length (cm) + 0.686) + (-0.387 * x_sepal length (cm) + -0.554) + (0.728 * x_sepal length (cm) + 0.455) + (-0.480 * x_sepal length (cm) + -0.706) + (0.673 * x_sepal length (cm) + 0.145) + (-0.348 * x_sepal length (cm) + -0.659) + (0.105 * x_sepal length (cm) + 0.500) + (0.090 * x_sepal length (cm) + 0.269) + (0.287 * x_sepal length (cm) + 0.576) + (-0.419 * x_sepal length (cm) + -0.608) + (0.961 * x_sepal width (cm) + -0.450) + (0.592 * x_sepal width (cm) + -0.219) + (0.073 * x_sepal width (cm) + -0.689) + (0.367 * x_sepal width (cm) + 0.776) + (-0.428 * x_sepal width (cm) + -0.851) + (-0.797 * x_sepal width (cm) + -0.281) + (0.333 * x_sepal width (cm) + 0.823) + (-0.657 * x_sepal width (cm) + 0.846) + (-0.363 * x_sepal width (cm) + -0.716) + (-0.585 * x_sepal width (cm) + -0.356) + (-0.045 * x_petal length (cm) + -0.050) + (0.077 * x_petal length (cm) + -0.362) + (0.816 * x_petal length (cm) + -0.846) + (

  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)


Epoch [10/50], Loss: 0.8763


  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)


Epoch [20/50], Loss: 0.8842


  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)


Epoch [30/50], Loss: 0.9541


  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)


Epoch [40/50], Loss: 0.7706


  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)
  testimages = torch.tensor(testimages, dtype=torch.float32)


Epoch [50/50], Loss: 0.8331
Accuracy: 33.33%
Model Equation: y = (0.664 * x_sepal length (cm) + -0.752) + (0.276 * x_sepal length (cm) + -0.238) + (-0.852 * x_sepal length (cm) + -0.019) + (0.570 * x_sepal length (cm) + 0.910) + (0.112 * x_sepal length (cm) + -0.555) + (-0.098 * x_sepal length (cm) + 0.702) + (-0.506 * x_sepal length (cm) + 0.895) + (-0.568 * x_sepal length (cm) + -0.537) + (-0.291 * x_sepal length (cm) + 0.135) + (-0.233 * x_sepal length (cm) + 0.329) + (-0.240 * x_sepal width (cm) + -0.795) + (0.736 * x_sepal width (cm) + 0.197) + (0.087 * x_sepal width (cm) + 0.684) + (-0.140 * x_sepal width (cm) + 0.729) + (-0.751 * x_sepal width (cm) + -0.232) + (-0.256 * x_sepal width (cm) + 0.162) + (0.565 * x_sepal width (cm) + -0.868) + (0.504 * x_sepal width (cm) + -0.151) + (-0.364 * x_sepal width (cm) + 0.818) + (0.842 * x_sepal width (cm) + 0.136) + (-0.181 * x_petal length (cm) + -0.623) + (0.346 * x_petal length (cm) + 0.310) + (0.921 * x_petal length (cm) + 0.640) + (-0

  testimages = torch.tensor(testimages, dtype=torch.float32)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np

# CIFAR-10 preprocessing: tensor conversion, then per-channel (x - 0.5) / 0.5
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

# Train split first, then test split; both downloaded into ./data on first use
trainset, testset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Shuffled batches for training, fixed order for evaluation
trainloader, testloader = (
    DataLoader(trainset, batch_size=32, shuffle=True),
    DataLoader(testset, batch_size=32, shuffle=False),
)

# Small two-stage CNN for 3-channel image classification
class SimpleCNN(nn.Module):
    """Two conv -> ReLU -> 2x2 max-pool stages, then a 128-unit hidden layer and 10 logits.

    The flatten size 32 * 8 * 8 means the input must be 32x32 (two halvings -> 8x8).
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes parameter-init order and the state_dict layout.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.pool(self.relu(conv(x)))
        hidden = self.relu(self.fc1(x.view(-1, 32 * 8 * 8)))
        return self.fc2(hidden)

# Training settings for the centrally trained CNN: cross-entropy loss, 20 epochs
criterion = nn.CrossEntropyLoss()
epochs = 20

def train_model(model, trainloader, epochs, criterion, *, lr=0.01, momentum=0.9):
    """Train ``model`` in place with SGD for ``epochs`` passes over ``trainloader``.

    Args:
        model: module mapping a batch of inputs to logits.
        trainloader: iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.
        epochs: number of full passes over ``trainloader``.
        criterion: loss called as ``criterion(outputs, labels)``.
        lr: SGD learning rate (default 0.01).
        momentum: SGD momentum (default 0.9).

    Returns:
        The same ``model`` object, left in training mode.
    """
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    model.train()
    for _ in range(epochs):
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
    return model

# Fit a freshly initialized CNN on the full training loader
model = SimpleCNN()
trained_model = train_model(model, trainloader, epochs=epochs, criterion=criterion)

# Gradient-saliency mask of the most influential input values
def get_top_contributing_pixels(model, data, percentage=0.3):
    """Return a boolean mask of the input values with the largest gradient saliency.

    Runs ``model`` on ``data`` (a batch of exactly one sample), takes the gradient
    of the predicted class's logit with respect to the input, and marks entries
    whose absolute gradient is at or above the ``(1 - percentage)`` quantile.

    The gradient is taken on a detached copy, so ``data`` is not modified: its
    ``requires_grad`` flag stays as it was and ``data.numpy()`` keeps working for
    the caller. ``model`` is left in eval mode.

    Args:
        model: callable returning logits of shape (1, num_classes) for ``data``.
        data: input tensor with a leading batch dimension of size 1.
        percentage: fraction in [0, 1] of entries to keep; ties at the threshold
            are kept as well.

    Returns:
        numpy bool array shaped like ``data`` with all size-1 dimensions squeezed.

    Raises:
        ValueError: if ``data`` is not a single-sample batch or ``percentage`` is
            outside [0, 1].
    """
    if data.dim() == 0 or data.size(0) != 1:
        raise ValueError(f"expected a batch of exactly one sample, got shape {tuple(data.shape)}")
    if not 0.0 <= percentage <= 1.0:
        raise ValueError(f"percentage must be in [0, 1], got {percentage}")

    model.eval()
    inputs = data.detach().clone().requires_grad_(True)
    output = model(inputs)
    pred_class = output.argmax(dim=1).item()

    # Gradient w.r.t. the input only; leaves the model parameters' .grad untouched.
    (grad,) = torch.autograd.grad(output[0, pred_class], inputs)
    saliency = grad.abs().squeeze().cpu().numpy()

    threshold = np.percentile(saliency.flatten(), 100 - percentage * 100)
    return saliency >= threshold

# Get a test image: first sample of the first (unshuffled) test batch, kept as a batch of one
test_images, _ = next(iter(testloader))
test_image = test_images[0].unsqueeze(0)

# Get the top 30% contributing input values; the mask has test_image's shape
# with the batch dimension squeezed away (one entry per channel value).
top_pixels_mask = get_top_contributing_pixels(trained_model, test_image)

# Display copies in [0, 1]: undo Normalize(mean=0.5, std=0.5) so imshow does not
# clip the negative half of the range. detach() is required because the saliency
# helper may have enabled requires_grad on test_image, and the arithmetic creates
# new arrays, so masking below cannot write through into test_image.
original_image = np.clip(test_image.detach().squeeze().cpu().numpy() * 0.5 + 0.5, 0.0, 1.0)
highlighted_image = original_image.copy()

mask_2d = top_pixels_mask
highlighted_image[mask_2d == 0] = 0  # Black out values outside the top set

# Plot the original and highlighted image
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Original Image")
plt.imshow(np.transpose(original_image, (1, 2, 0)))
plt.axis('off')

plt.subplot(1, 2, 2)
plt.title("Highlighted Image")
plt.imshow(np.transpose(highlighted_image, (1, 2, 0)))
plt.axis('off')
plt.show()

# Hyper-parameter container shared by the NAM models and their training loops
class Config:
    """Plain bag of training hyper-parameters.

    Attributes:
        dropout: dropout probability applied to the concatenated feature outputs.
        learning_rate: optimizer step size.
        num_epochs: number of training epochs.
        batch_size: intended mini-batch size.
    """

    def __init__(self, dropout=0.5, learning_rate=0.001, num_epochs=50, batch_size=32):
        self.dropout, self.learning_rate, self.num_epochs, self.batch_size = (
            dropout,
            learning_rate,
            num_epochs,
            batch_size,
        )

# Minimal base module: stores the config object and a model name
class Model(nn.Module):
    """nn.Module base class that records ``config`` and ``name`` as attributes."""

    def __init__(self, config, name):
        super().__init__()
        self.config = config
        self.name = name

# Per-feature subnetwork used by NAM
class FeatureNN(nn.Module):
    """One Linear(input_shape, num_units) layer followed by ReLU.

    ``config``, ``name``, ``input_shape``, ``num_units`` and ``feature_num`` are
    stored as attributes; only ``fc`` takes part in the forward pass.
    """

    def __init__(self, config, name, input_shape, num_units, feature_num):
        super().__init__()
        self.config = config
        self.name = name
        self.input_shape = input_shape
        self.num_units = num_units
        self.feature_num = feature_num
        self.fc = nn.Linear(input_shape, num_units)

    def forward(self, x):
        return F.relu(self.fc(x))

# NAM model definition
class NAM(Model):
    """Neural-additive-style classifier with one small subnetwork per input feature.

    Column ``i`` of the (flattened) input goes through its own ``FeatureNN``
    (Linear(1, num_units[i]) + ReLU). The per-feature outputs are concatenated,
    passed through dropout, and mapped by one linear layer to ``num_classes`` logits.

    Args:
        config: object exposing ``dropout`` (the dropout probability).
        name: model name, stored by ``Model``.
        num_inputs: number of input features (columns).
        num_units: hidden width per feature -- one int shared by all features, or a
            list with exactly ``num_inputs`` entries.
        num_classes: number of output logits (default 10, the CIFAR-10 classes).

    Raises:
        ValueError: if ``num_units`` is a list of the wrong length.
        TypeError: if ``num_units`` is neither an int nor a list.
    """

    def __init__(self, config, name, *, num_inputs: int, num_units: int, num_classes: int = 10) -> None:
        super(NAM, self).__init__(config, name)
        self._num_inputs = num_inputs
        self.dropout = nn.Dropout(p=self.config.dropout)

        if isinstance(num_units, list):
            # Explicit check instead of assert, which is stripped under `python -O`.
            if len(num_units) != num_inputs:
                raise ValueError(
                    f"num_units has {len(num_units)} entries but num_inputs is {num_inputs}"
                )
            self._num_units = list(num_units)
        elif isinstance(num_units, int):
            self._num_units = [num_units] * num_inputs
        else:
            # Fail fast rather than later on a missing self._num_units attribute.
            raise TypeError(f"num_units must be an int or a list, got {type(num_units).__name__}")

        self.feature_nns = nn.ModuleList([
            FeatureNN(config=config, name=f'FeatureNN_{i}', input_shape=1, num_units=self._num_units[i], feature_num=i)
            for i in range(num_inputs)
        ])

        self.output_layer = nn.Linear(sum(self._num_units), num_classes)
        # Scalar parameter that forward() never uses; only print_model_equation reports it.
        self._bias = torch.nn.Parameter(data=torch.zeros(1))

    def calc_outputs(self, inputs: torch.Tensor) -> Sequence[torch.Tensor]:
        """Apply each feature subnetwork to its own column (inputs indexed as [batch, feature])."""
        return [fnn(inputs[:, i:i + 1]) for i, fnn in enumerate(self.feature_nns)]

    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(logits, dropped_features)``.

        ``dropped_features`` is the concatenation of all feature-subnetwork outputs
        after dropout; ``logits`` is ``output_layer`` applied to it.
        """
        individual_outputs = self.calc_outputs(inputs)
        conc_out = torch.cat(individual_outputs, dim=-1)
        dropout_out = self.dropout(conc_out)
        out = self.output_layer(dropout_out)
        return out, dropout_out

    def print_model_equation(self, feature_names):
        """Print a textual summary of every feature subnetwork's first layer.

        Each printed term is one affine map ``w * x + b`` of a FeatureNN's Linear
        layer (before ReLU). The ReLU, dropout and output layer are not included, so
        this summarizes the feature subnetworks rather than the exact model function.
        A feature's contribution score is the sum of absolute first-layer weights.

        Args:
            feature_names: names indexed by feature position.

        Returns:
            The name of the feature with the largest contribution score.
        """
        equation_terms = []
        feature_contributions = {}
        for i, fnn in enumerate(self.feature_nns):
            coefficients = fnn.fc.weight.data.flatten().tolist()
            intercepts = fnn.fc.bias.data.tolist()
            term = " + ".join([f"({coeff:.3f} * x_{feature_names[i]} + {intercept:.3f})" for coeff, intercept in zip(coefficients, intercepts)])
            equation_terms.append(term)
            feature_contributions[feature_names[i]] = sum(abs(c) for c in coefficients)
        equation = " + ".join(equation_terms) + f" + bias ({self._bias.item():.3f})"
        print(f"Model Equation: y = {equation}")

        # Rank features by total absolute first-layer weight, largest first.
        interpretability = sorted(feature_contributions.items(), key=lambda x: x[1], reverse=True)
        print("\nFeature Contributions:")
        for feature, contribution in interpretability:
            print(f"{feature}: {contribution:.3f}")

        return interpretability[0][0]  # Return the feature with the highest contribution

# Hyper-parameters for the CIFAR-10 NAM
config = Config(dropout=0.5, learning_rate=0.001, num_epochs=50, batch_size=32)

# One input feature per flattened value: 3 channels x 32 x 32 pixels
num_inputs = 3 * 32 * 32
num_units = 10  # Hidden width of each per-feature subnetwork
nam_model = NAM(config=config, name='NAM_Model', num_inputs=num_inputs, num_units=num_units)

# Training function for NAM model
def train_nam_model(model, trainloader, config):
    """Train ``model`` with Adam and cross-entropy over ``trainloader``.

    Each batch's inputs are flattened to (batch, -1) before the forward pass, and
    ``model`` must return ``(logits, features)`` like ``NAM.forward``. Runs
    ``config.num_epochs`` epochs at ``config.learning_rate``. Batching comes from
    ``trainloader``; ``config.batch_size`` is not read here. Every 10th epoch the
    mean batch loss of that epoch is printed (``nan`` if the loader was empty).

    Returns:
        The same ``model`` object, left in training mode.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    model.train()
    for epoch in range(config.num_epochs):
        running_loss = 0.0
        n_batches = 0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            # reshape (unlike view) also accepts non-contiguous batches.
            outputs, _ = model(inputs.reshape(inputs.size(0), -1))
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
        if (epoch + 1) % 10 == 0:
            # Counting batches avoids a ZeroDivisionError on an empty loader.
            mean_loss = running_loss / n_batches if n_batches else float('nan')
            print(f'Epoch [{epoch + 1}/{config.num_epochs}], Loss: {mean_loss:.4f}')
    return model

# Fit the NAM on the CIFAR-10 training loader
trained_nam_model = train_nam_model(model=nam_model, trainloader=trainloader, config=config)

# Evaluate the NAM model
def evaluate_nam_model(model, testloader):
    """Print and return the classification accuracy of ``model`` on ``testloader``.

    Inputs are flattened to (batch, -1) and ``model`` must return
    ``(logits, features)``. The model is put into eval mode and run without autograd.

    Returns:
        Accuracy as a fraction in [0, 1].

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            outputs, _ = model(inputs.reshape(inputs.size(0), -1))
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("testloader yielded no samples; accuracy is undefined")
    accuracy = correct / total
    print(f'Accuracy: {accuracy * 100:.2f}%')
    return accuracy

evaluate_nam_model(model=trained_nam_model, testloader=testloader)

# Name each flattened input position, then report the feature whose
# first-layer weights have the largest total magnitude.
feature_names = ['pixel_' + str(idx) for idx in range(num_inputs)]
most_contributing_feature = trained_nam_model.print_model_equation(feature_names)
print(f"\nMost contributing feature: {most_contributing_feature}")


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 43103711.19it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them
# to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# MNIST train and test splits (download=True fetches them into ./data if missing).
trainset, testset = (
    torchvision.datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Shuffle only the training stream; keep test order fixed.
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
testloader = DataLoader(testset, batch_size=32, shuffle=False)

# Define a simple CNN model for image classification
class SimpleCNN(nn.Module):
    """Two-convolution CNN for 1x28x28 grayscale images with 10 output classes.

    Each 3x3 convolution (padding 1) keeps the spatial size and each 2x2
    max-pool halves it, so a 28x28 input reaches the classifier head as
    32 x 7 x 7 features.
    """

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        """Map a ``(batch, 1, 28, 28)`` tensor to ``(batch, 10)`` logits."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten per sample so the batch dimension is preserved. With
        # view(-1, 32*7*7) a wrongly sized input could be silently re-split
        # into extra "samples"; now it fails loudly at fc1 instead.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

# Hyperparameters for the centralized SimpleCNN run below: one model is
# trained on the full training loader (no per-client split, nothing saved).
epochs = 20
criterion = nn.CrossEntropyLoss()

def train_model(model, trainloader, epochs, criterion, lr=0.01, momentum=0.9):
    """Train ``model`` in place with SGD and return it.

    Args:
        model: The ``nn.Module`` to optimize; it is left in training mode.
        trainloader: Iterable of ``(inputs, labels)`` batches.
        epochs: Number of full passes over ``trainloader``.
        criterion: Loss callable taking ``(outputs, labels)``.
        lr: SGD learning rate (defaults to the previously hard-coded 0.01).
        momentum: SGD momentum (defaults to the previously hard-coded 0.9).

    Returns:
        The same ``model`` object, after training.
    """
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    model.train()
    # No per-batch loss accumulation: the old running_loss was never read and
    # its .item() call forced a host sync on every batch.
    for _ in range(epochs):
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
    return model

# Build a fresh SimpleCNN and fit it on the training loader.
model = SimpleCNN()
trained_model = train_model(model=model, trainloader=trainloader, epochs=epochs, criterion=criterion)

# Function to get the top 30% contributing pixels
def get_top_contributing_pixels(model, data, percentage=0.3):
    """Return a boolean mask of the input elements with the largest saliency.

    Saliency here is the absolute gradient of the predicted class's score
    with respect to the input.

    Args:
        model: Classifier mapping ``data`` to ``(batch, num_classes)`` scores.
            It is switched to eval mode.
        data: Input tensor holding a single sample (batch size 1: the
            prediction is read with ``.item()`` and ``output[0, ...]``).
            The caller's tensor is not modified.
        percentage: Fraction of elements to keep, e.g. 0.3 for the top 30%.

    Returns:
        NumPy boolean array shaped like ``data`` with size-1 dims squeezed,
        True where the gradient magnitude is at or above the
        ``100 * (1 - percentage)`` percentile.
    """
    model.eval()
    # Differentiate w.r.t. a detached copy. Calling requires_grad_() on the
    # caller's tensor in place made a later `tensor.numpy()` on it raise
    # "Can't call numpy() on Tensor that requires grad".
    inputs = data.detach().clone().requires_grad_(True)
    output = model(inputs)
    pred_class = output.argmax(dim=1).item()

    # autograd.grad returns the input gradient directly and leaves the model
    # parameters' .grad fields untouched.
    (input_grad,) = torch.autograd.grad(output[0, pred_class], inputs)
    grad = input_grad.abs().squeeze().cpu().numpy()

    # Keep every element at or above the (100 - 100*percentage)th percentile.
    threshold = np.percentile(grad.flatten(), 100 - percentage * 100)
    return grad >= threshold

# Take the first image of the first test batch, keeping a batch dimension of 1.
test_images, _ = next(iter(testloader))
test_image = test_images[0].unsqueeze(0)

# Boolean mask of the top 30% pixels by gradient saliency.
top_pixels_mask = get_top_contributing_pixels(trained_model, test_image)

# .detach() keeps .numpy() legal even if test_image was flagged
# requires_grad upstream. .numpy() shares memory with the tensor, so copy
# before masking; otherwise zeroing pixels here would also blank them in
# test_image and the "Original Image" panel would show the masked image.
original_image = test_image.detach().squeeze().cpu().numpy()
highlighted_image = original_image.copy()

mask_2d = top_pixels_mask
highlighted_image[~mask_2d] = 0  # zero every non-top pixel

# Plot the original and highlighted image side by side.
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Original Image")
plt.imshow(original_image, cmap='gray')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.title("Highlighted Image")
plt.imshow(highlighted_image, cmap='gray')
plt.axis('off')
plt.show()

class Config:
    """Plain container for NAM training hyperparameters.

    Attributes:
        dropout: Dropout probability.
        learning_rate: Optimizer step size.
        num_epochs: Number of passes over the training data.
        batch_size: Mini-batch size. NOTE(review): the DataLoaders in this
            notebook use a literal 32 rather than reading this field.
    """

    def __init__(self, dropout=0.5, learning_rate=0.001, num_epochs=50, batch_size=32):
        self.dropout, self.learning_rate = dropout, learning_rate
        self.num_epochs, self.batch_size = num_epochs, batch_size

class Model(nn.Module):
    """Minimal base module that stores a config object and a name."""

    def __init__(self, config, name):
        super().__init__()
        self.config, self.name = config, name

class FeatureNN(nn.Module):
    """Small subnetwork: a single linear layer followed by ReLU.

    Args:
        config: Configuration object, stored but not read here.
        name: Identifier for this subnetwork.
        input_shape: Input size of the linear layer.
        num_units: Output size of the linear layer.
        feature_num: Feature index, stored for bookkeeping.
    """

    def __init__(self, config, name, input_shape, num_units, feature_num):
        super().__init__()
        self.config = config
        self.name = name
        self.input_shape = input_shape
        self.num_units = num_units
        self.feature_num = feature_num
        self.fc = nn.Linear(input_shape, num_units)

    def forward(self, x):
        return F.relu(self.fc(x))

# NAM model definition
class NAM(Model):
    """Additive-style classifier with one small subnetwork per input feature.

    Column ``i`` of the input is fed to its own ``FeatureNN`` (input size 1,
    ``num_units[i]`` outputs). The per-feature outputs are concatenated,
    passed through dropout, and mapped to class logits by ``output_layer``.

    Args:
        config: Object exposing ``dropout`` (probability for ``nn.Dropout``).
        name: Model name, forwarded to ``Model``.
        num_inputs: Number of input features (columns of the input).
        num_units: Units per FeatureNN; either one int shared by every
            feature, or a list/tuple with exactly ``num_inputs`` entries.
        num_classes: Size of the logit vector (default 10, e.g. MNIST digits).

    Raises:
        ValueError: if a per-feature ``num_units`` sequence has the wrong length.
        TypeError: if ``num_units`` is neither an int nor a list/tuple.
    """

    def __init__(self, config, name, *, num_inputs: int, num_units: "int | list[int]",
                 num_classes: int = 10) -> None:
        super(NAM, self).__init__(config, name)
        self._num_inputs = num_inputs
        self.dropout = nn.Dropout(p=self.config.dropout)

        # Validate with real exceptions (asserts vanish under `python -O`) and
        # reject unsupported types up front instead of leaving _num_units unset.
        if isinstance(num_units, (list, tuple)):
            if len(num_units) != num_inputs:
                raise ValueError(
                    f"num_units has {len(num_units)} entries, expected num_inputs={num_inputs}"
                )
            self._num_units = list(num_units)
        elif isinstance(num_units, int):
            self._num_units = [num_units] * num_inputs
        else:
            raise TypeError(
                f"num_units must be an int or a list of ints, got {type(num_units).__name__}"
            )

        self.feature_nns = nn.ModuleList([
            FeatureNN(config=config, name=f'FeatureNN_{i}', input_shape=1,
                      num_units=self._num_units[i], feature_num=i)
            for i in range(num_inputs)
        ])

        self.output_layer = nn.Linear(sum(self._num_units), num_classes)
        # NOTE(review): this scalar is only reported by print_model_equation;
        # forward() never adds it (output_layer carries its own bias).
        self._bias = torch.nn.Parameter(data=torch.zeros(1))

    def calc_outputs(self, inputs: torch.Tensor) -> "list[torch.Tensor]":
        """Run each FeatureNN on its own column ``inputs[:, i:i+1]``; one tensor per feature."""
        return [fnn(inputs[:, i:i + 1]) for i, fnn in enumerate(self.feature_nns)]

    def forward(self, inputs: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return ``(logits, features)``.

        ``features`` is the concatenation of all per-feature outputs after
        dropout; ``logits`` is ``output_layer(features)``.
        """
        features = self.dropout(torch.cat(self.calc_outputs(inputs), dim=-1))
        return self.output_layer(features), features

    def print_model_equation(self, feature_names):
        """Print per-feature affine terms and rank features by weight magnitude.

        Only each FeatureNN's linear layer is printed: the ReLU and the shared
        ``output_layer`` are not reflected, so the printed "equation" is a
        summary of the feature subnetworks, not the exact model function. A
        feature's contribution score is the sum of the absolute values of its
        FeatureNN input weights.

        Args:
            feature_names: Display names, one per input feature (extra names
                are ignored).

        Returns:
            Name of the feature with the largest contribution score.

        Raises:
            ValueError: if the model has no features, or ``feature_names`` is
                shorter than the number of features.
        """
        if len(self.feature_nns) == 0:
            raise ValueError("model has no features to describe")
        if len(feature_names) < len(self.feature_nns):
            raise ValueError(
                f"need {len(self.feature_nns)} feature names, got {len(feature_names)}"
            )

        equation_terms = []
        feature_contributions = {}
        for feature_name, fnn in zip(feature_names, self.feature_nns):
            coefficients = fnn.fc.weight.detach().flatten().tolist()
            intercepts = fnn.fc.bias.detach().tolist()
            equation_terms.append(" + ".join(
                f"({coeff:.3f} * x_{feature_name} + {intercept:.3f})"
                for coeff, intercept in zip(coefficients, intercepts)
            ))
            feature_contributions[feature_name] = sum(abs(c) for c in coefficients)
        equation = " + ".join(equation_terms) + f" + bias ({self._bias.item():.3f})"
        print(f"Model Equation: y = {equation}")

        # Rank features from largest to smallest contribution score.
        interpretability = sorted(feature_contributions.items(), key=lambda kv: kv[1], reverse=True)
        print("\nFeature Contributions:")
        for feature, contribution in interpretability:
            print(f"{feature}: {contribution:.3f}")

        return interpretability[0][0]  # Return the feature with the highest contribution

# One NAM input per MNIST pixel, each with its own 10-unit subnetwork.
num_inputs, num_units = 28 * 28, 10

config = Config(dropout=0.5, learning_rate=0.001, num_epochs=50, batch_size=32)
nam_model = NAM(
    config=config,
    name='NAM_Model',
    num_inputs=num_inputs,
    num_units=num_units,
)

def train_nam_model(model, trainloader, config):
    """Train a NAM with Adam and cross-entropy, printing mean loss every 10 epochs.

    Args:
        model: Module whose forward returns ``(logits, extra)``.
        trainloader: Sized iterable of ``(inputs, labels)`` batches; each
            input batch is flattened to ``(batch, -1)`` before the forward pass.
        config: Object exposing ``learning_rate`` and ``num_epochs``.

    Returns:
        The same ``model``, trained in place and left in training mode.
    """
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    model.train()
    for epoch_idx in range(1, config.num_epochs + 1):
        epoch_loss = 0.0
        for batch_inputs, batch_labels in trainloader:
            opt.zero_grad()
            flat = batch_inputs.view(batch_inputs.size(0), -1)
            logits, _ = model(flat)
            batch_loss = loss_fn(logits, batch_labels)
            batch_loss.backward()
            opt.step()
            epoch_loss += batch_loss.item()
        if epoch_idx % 10 == 0:
            print(f'Epoch [{epoch_idx}/{config.num_epochs}], Loss: {epoch_loss / len(trainloader):.4f}')
    return model

# Fit the NAM defined above; train_nam_model returns the same, now-trained module.
trained_nam_model = train_nam_model(
    model=nam_model,
    trainloader=trainloader,
    config=config,
)

# Evaluate the NAM model
def evaluate_nam_model(model, testloader):
    """Print and return the classification accuracy of a NAM on a test loader.

    Args:
        model: Module whose forward returns ``(logits, extra)``; only the
            logits are used. It is switched to eval mode.
        testloader: Iterable of ``(inputs, labels)`` batches. Each input batch
            is reshaped to ``(batch, -1)`` before the forward pass.

    Returns:
        Accuracy as a float in ``[0, 1]`` (also printed as a percentage).

    Raises:
        ValueError: if ``testloader`` yields no samples, since accuracy is
            undefined in that case.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            # reshape (not view) so non-contiguous batches are accepted too.
            logits, _ = model(inputs.reshape(inputs.size(0), -1))
            predicted = logits.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("testloader yielded no samples; accuracy is undefined")
    accuracy = correct / total
    print(f'Accuracy: {accuracy * 100:.2f}%')
    return accuracy

evaluate_nam_model(trained_nam_model, testloader)

# Label each input column by its pixel index, print the NAM's per-feature
# equation, and report which feature ranks first.
feature_names = [f'pixel_{idx}' for idx in range(num_inputs)]
most_contributing_feature = trained_nam_model.print_model_equation(feature_names)
print(f"\nMost contributing feature: {most_contributing_feature}")


RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.