# In[ ]:
import matplotlib.pyplot as plt
import numpy as np
from collections import Counter
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor
from torch.utils.data import Subset
import random
import seaborn as sns
from matplotlib import rcParams

# Optionally set the font to CMU Serif (the font must be installed on the system):
#plt.rcParams['font.family'] = 'CMU Serif'

# Build IID clients whose class histograms are exactly uniform.
def create_perfect_iid_clients(num_clients, dataset, samples_per_client=2500, num_classes=10):
    """Create ``num_clients`` clients with a perfectly balanced class distribution.

    Each client receives ``samples_per_client // num_classes`` indices from every
    class, drawn uniformly at random without replacement *within* that client.
    Clients are sampled independently, so the same dataset index may be
    assigned to more than one client.

    Args:
        num_clients: Number of clients to create.
        dataset: Indexable dataset yielding ``(sample, label)`` pairs with
            integer labels in ``range(num_classes)``. If it exposes a
            ``targets`` attribute (torchvision's CIFAR10 does), labels are read
            from it instead of decoding and transforming every sample.
        samples_per_client: Total samples per client; any remainder after
            dividing by ``num_classes`` is dropped.
        num_classes: Number of classes (default 10, as for CIFAR-10).

    Returns:
        A list of ``(Subset, Counter)`` pairs: the client's view of ``dataset``
        and its per-class sample counts.

    Raises:
        ValueError: From ``random.sample`` if some class has fewer than
            ``samples_per_client // num_classes`` samples.
    """
    per_class = samples_per_client // num_classes

    # Read every label once. `targets` avoids running the image transform on
    # the whole dataset; fall back to iterating the dataset otherwise.
    targets = getattr(dataset, "targets", None)
    if targets is not None:
        labels = [int(t) for t in targets]
    else:
        labels = [int(label) for _, label in dataset]

    class_to_indices = {cls: [] for cls in range(num_classes)}
    for idx, label in enumerate(labels):
        class_to_indices[label].append(idx)

    clients_data = []
    for _ in range(num_clients):
        selected_indices = []
        subset_labels = []
        for cls in range(num_classes):
            selected = random.sample(class_to_indices[cls], per_class)
            selected_indices.extend(selected)
            # The label of each selected index is known; no need to re-decode samples.
            subset_labels.extend([cls] * len(selected))
        clients_data.append((Subset(dataset, selected_indices), Counter(subset_labels)))
    return clients_data

# Build Non-IID clients whose class proportions are drawn from a Dirichlet distribution.
def create_non_iid_clients(num_clients, dataset, samples_per_client=2500, alpha=0.5, num_classes=10):
    """Create ``num_clients`` clients with label-skewed (Non-IID) class distributions.

    For every client, class proportions are drawn from
    ``np.random.dirichlet([alpha] * num_classes)``; smaller ``alpha`` gives more
    skewed proportions. The proportions are turned into integer per-class counts
    that sum to ``samples_per_client``, and indices of each class are sampled
    without replacement within the client. Clients are sampled independently,
    so they may share indices.

    Args:
        num_clients: Number of clients to create.
        dataset: Indexable dataset yielding ``(sample, label)`` pairs with
            integer labels in ``range(num_classes)``. If it exposes a
            ``targets`` attribute (torchvision's CIFAR10 does), labels are read
            from it instead of decoding and transforming every sample.
        samples_per_client: Target number of samples per client. A class's count
            is capped at the number of samples of that class available, so the
            total can be lower when a class is too small.
        alpha: Dirichlet concentration parameter.
        num_classes: Number of classes (default 10, as for CIFAR-10).

    Returns:
        A list of ``(Subset, Counter)`` pairs: the client's view of ``dataset``
        and its per-class sample counts.
    """
    # Index the dataset by class ONCE; building this inside the client loop
    # re-decoded the entire dataset for every client.
    targets = getattr(dataset, "targets", None)
    if targets is not None:
        labels = [int(t) for t in targets]
    else:
        labels = [int(label) for _, label in dataset]

    class_to_indices = {cls: [] for cls in range(num_classes)}
    for idx, label in enumerate(labels):
        class_to_indices[label].append(idx)

    clients_data = []
    for _ in range(num_clients):
        proportions = np.random.dirichlet([alpha] * num_classes)
        raw_counts = proportions * samples_per_client
        class_counts = raw_counts.astype(int)  # floor, since proportions are >= 0
        # Flooring loses up to `num_classes` samples; give them back to the
        # classes with the largest fractional parts so the total is exact.
        shortfall = int(samples_per_client - class_counts.sum())
        if shortfall > 0:
            largest_remainders = np.argsort(raw_counts - class_counts)[::-1]
            class_counts[largest_remainders[:shortfall]] += 1

        selected_indices = []
        subset_labels = []
        for cls, count in enumerate(class_counts):
            available_indices = class_to_indices[cls]
            selected = random.sample(available_indices, min(int(count), len(available_indices)))
            selected_indices.extend(selected)
            # The label of each selected index is known; no need to re-decode samples.
            subset_labels.extend([cls] * len(selected))

        clients_data.append((Subset(dataset, selected_indices), Counter(subset_labels)))
    return clients_data

# Simulate GAN-based rebalancing by topping up under-represented class counts.
def augment_with_gan(clients_data, target_samples_per_class=250):
    """Return copies of the clients' class counts with every class raised to a floor.

    No GAN is trained and no samples are generated: each client's subset is
    passed through unchanged. Only a copy of its class-count ``Counter`` is
    modified: every class in ``range(10)`` whose count is below
    ``target_samples_per_class`` is raised to exactly that value. Counts
    already at or above the target are left as they are. The input counters
    are not mutated.

    Args:
        clients_data: Iterable of ``(subset, class_distribution)`` pairs, where
            ``class_distribution`` is a ``Counter`` of class index -> count.
        target_samples_per_class: Minimum count per class after the top-up.

    Returns:
        A list of ``(subset, topped_up_distribution)`` pairs.
    """
    def _top_up(distribution):
        # Reading a missing class from a Counter yields 0 without inserting it.
        topped = distribution.copy()
        for label in range(10):
            missing = target_samples_per_class - topped[label]
            if missing > 0:
                topped[label] += missing
        return topped

    return [(subset, _top_up(distribution)) for subset, distribution in clients_data]

# Load the CIFAR-10 training set (downloaded into the working directory if missing).
transform = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
cifar10_train = CIFAR10(root="./", train=True, download=True, transform=transform)

# Seed both RNGs used for client sampling (`random` for index sampling, NumPy
# for the Dirichlet proportions) so the client splits and figures are
# reproducible across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Create 4 perfectly IID clients and 4 Non-IID clients.
iid_clients = create_perfect_iid_clients(4, cifar10_train, samples_per_client=2500)
non_iid_clients = create_non_iid_clients(4, cifar10_train, samples_per_client=2500, alpha=0.5)

# Apply the simulated GAN augmentation to the Non-IID clients (it only tops up
# the class counts used for plotting; no samples are generated).
non_iid_clients_augmented = augment_with_gan(non_iid_clients)

# Collect per-client class counts for plotting: clients 1-4 are IID, 5-8 Non-IID.
clients_distributions_before = [dist for _, dist in iid_clients + non_iid_clients]
clients_distributions_after = [dist for _, dist in iid_clients + non_iid_clients_augmented]

# Draw one horizontal 100% stacked bar per client showing its class mix.
def plot_distributions(distributions, class_names=None):
    """Plot each client's class distribution as a 100% stacked horizontal bar.

    Args:
        distributions: Sequence of mappings (e.g. ``Counter``) from class index
            to sample count, one per client. Missing classes count as zero; a
            client with no samples is drawn as an empty bar.
        class_names: Legend labels, one per class; the plotted class indices
            are ``range(len(class_names))``. Defaults to the ten CIFAR-10
            classes.

    Rows are labelled ``Client 1`` ... ``Client N`` in the order given (from
    the bottom of the axes upwards). The figure is shown with ``plt.show()``;
    nothing is returned.
    """
    if class_names is None:
        class_names = [
            "Airplane", "Automobile", "Bird", "Cat", "Deer",
            "Dog", "Frog", "Horse", "Ship", "Truck",
        ]
    num_classes = len(class_names)

    distributions_percentages = []
    for distribution in distributions:
        total_samples = sum(distribution.values())
        if total_samples == 0:
            # Avoid ZeroDivisionError for an empty client; draw an empty bar.
            percentages = [0.0] * num_classes
        else:
            percentages = [distribution.get(cls, 0) / total_samples * 100 for cls in range(num_classes)]
        distributions_percentages.append(percentages)

    # One color per class from seaborn's red-yellow-green palette.
    cmap = sns.color_palette("RdYlGn", n_colors=num_classes)

    fig, ax = plt.subplots(figsize=(12, 8))

    # Vertical spacing between client bars.
    row_gap = 1.5
    rows = np.arange(len(distributions_percentages)) * row_gap

    # Stack each client's class segments left to right.
    for i, percentages in enumerate(distributions_percentages):
        left_offset = 0
        for j, percentage in enumerate(percentages):
            ax.barh(rows[i], percentage, color=cmap[j], edgecolor="black", height=0.8, left=left_offset)
            left_offset += percentage

    ax.set_xlabel("Class Distribution (%)", fontsize=12)
    ax.set_ylabel("Clients", fontsize=12)
    ax.set_yticks(rows)
    # One label per plotted row, so any number of clients is supported.
    ax.set_yticklabels([f"Client {i + 1}" for i in range(len(rows))])
    ax.set_xticks(np.arange(0, 101, 25))
    ax.grid(axis="x", linestyle="--", alpha=0.5)

    # Legend centered above the plot.
    handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in cmap]
    fig.legend(handles, class_names, title="Class", loc="upper center", bbox_to_anchor=(0.55, 1.11), ncol=num_classes, fontsize=10)

    plt.tight_layout(rect=[0, 0, 1, 0.95])  # leave room at the top for the legend
    plt.show()

# Plot the class distributions twice: first before, then after the simulated
# GAN augmentation of the Non-IID clients.
for _dists in (clients_distributions_before, clients_distributions_after):
    plot_distributions(_dists)