# In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import math
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, Dataset, ConcatDataset
import torchvision.datasets as datasets
import time
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR
import os
import numpy as np
from PIL import Image
import csv
import random
import kagglehub
import glob
from sklearn.model_selection import train_test_split
from torchvision.transforms import ToPILImage

# Every artifact this script writes (plots, CSV, checkpoint) goes under this folder.
output_dir = "FL_VEHICLE_NON_IID_ANOTHER_DISTRIBUTION_500AUG"
os.makedirs(output_dir, exist_ok=True)

# Fetch the Vehicle Type Image Dataset through kagglehub. A failure is
# reported and re-raised because nothing below can run without the data.
try:
    path = kagglehub.dataset_download("sujaykapadnis/vehicle-type-image-dataset")
except Exception as e:
    print(f"Failed to download dataset: {e}")
    raise
else:
    print("Path to dataset files:", path)
    dataset_path = path

# Deterministic preprocessing shared by training, validation and test data:
# resize to 128x128, convert to a tensor, then normalise each RGB channel.
base_transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Print the directory tree (first five files per folder) as a debugging aid.
print("Inspecting dataset path:", dataset_path)
for root, dirs, files in os.walk(dataset_path):
    for line in (
        f"Root: {root}",
        f"Dirs: {dirs}",
        f"Files (first 5): {files[:5]}",
        "-" * 50,
    ):
        print(line)

class VehicleTypeDataset(Dataset):
    """Image-classification dataset discovered by scanning a directory tree.

    Every directory below ``root_dir`` that directly contains at least one
    ``.jpg``/``.jpeg``/``.png`` file is treated as a class whose name is the
    directory's base name (equally named folders in different parents share
    one label). Directories and files are visited in sorted order so that
    sample indices and label ids are identical for every instance built from
    the same tree; this script shares index lists between instances that
    differ only in their transform, so the ordering must be reproducible.

    Args:
        root_dir: Top of the directory tree to scan.
        transform: Optional callable applied to each loaded PIL image.

    Raises:
        ValueError: If no image file is found anywhere below ``root_dir``.
    """

    # File suffixes (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []        # image file paths, one per sample
        self.labels = []        # integer class id per sample, aligned with images
        self.class_names = []   # class id -> folder name
        self.class_to_idx = {}  # folder name -> class id

        print(f"Searching for images in {root_dir}")
        for root, dirs, files in os.walk(root_dir):
            # Sorting in place fixes the order in which os.walk descends.
            dirs.sort()
            image_files = sorted(
                f for f in files if f.lower().endswith(self.IMAGE_EXTENSIONS)
            )
            if not image_files:
                continue

            class_name = os.path.basename(root)
            if class_name not in self.class_to_idx:
                self.class_to_idx[class_name] = len(self.class_names)
                self.class_names.append(class_name)
            label = self.class_to_idx[class_name]

            for img_file in image_files:
                self.images.append(os.path.join(root, img_file))
                self.labels.append(label)

        if not self.images:
            raise ValueError(
                f"No images found in {root_dir}. "
                "Expected class folders containing .jpg, .png, or .jpeg images."
            )

        print(f"Found {len(self.images)} images across {len(self.class_names)} classes.")
        print(f"Classes: {self.class_names}")

    def __len__(self):
        """Return the number of image files discovered."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` as RGB and return ``(image, label)``.

        The image is passed through ``self.transform`` when one was given.
        """
        img_path = self.images[idx]
        label = self.labels[idx]
        # The context manager closes the underlying file once the converted
        # copy exists, instead of leaving the handle to the garbage collector.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Build a transform-free view of the dataset; only its label list is needed
# to compute the splits, so no image has to be decoded here.
try:
    dataset_no_transform = VehicleTypeDataset(dataset_path, transform=None)
except Exception as load_error:
    print(f"Failed to load dataset: {load_error}")
    raise

# The number of discovered class folders is used as the label dimension.
label_dim = len(dataset_no_transform.class_names)
print(f"Number of classes (label_dim): {label_dim}")

# Step 1: Stratified train / validation / test split (80 / 10 / 10 per class)
train_ratio = 0.8
validation_ratio = 0.1
test_ratio = 0.1

# Bucket sample indices by class straight from the label list (no image I/O).
class_datasets = [[] for _ in range(label_dim)]
for sample_idx, sample_label in enumerate(dataset_no_transform.labels):
    class_datasets[sample_label].append(sample_idx)

# For each class: shuffle its indices, then cut them into three consecutive
# slices. The test slice takes whatever remains so no sample is dropped.
train_indices_per_class = []
val_indices_per_class = []
test_indices_per_class = []

for class_idx, class_members in enumerate(class_datasets):
    class_size = len(class_members)
    train_end = int(class_size * train_ratio)
    val_end = train_end + int(class_size * validation_ratio)

    random.shuffle(class_members)

    train_part = class_members[:train_end]
    val_part = class_members[train_end:val_end]
    test_part = class_members[val_end:]

    train_indices_per_class.append(train_part)
    val_indices_per_class.append(val_part)
    test_indices_per_class.append(test_part)

    print(f"Class {class_idx}: Train={len(train_part)}, Val={len(val_part)}, Test={len(test_part)}")

# Sanity check: the three slices of every class must be pairwise disjoint.
for class_idx, (train_part, val_part, test_part) in enumerate(
    zip(train_indices_per_class, val_indices_per_class, test_indices_per_class)
):
    train_members = set(train_part)
    val_members = set(val_part)
    test_members = set(test_part)

    assert not (train_members & val_members), f"Overlap between train and val for class {class_idx}"
    assert not (train_members & test_members), f"Overlap between train and test for class {class_idx}"
    assert not (val_members & test_members), f"Overlap between val and test for class {class_idx}"

# Second dataset instance over the same files, this time applying
# base_transform so that samples come out as normalised tensors.
dataset = VehicleTypeDataset(root_dir=dataset_path, transform=base_transform)

# Validation and test sets: all per-class index lists flattened into one.
val_indices_flat = []
for per_class_indices in val_indices_per_class:
    val_indices_flat.extend(per_class_indices)
val_dataset = Subset(dataset, val_indices_flat)

test_indices_flat = []
for per_class_indices in test_indices_per_class:
    test_indices_flat.extend(per_class_indices)
test_dataset = Subset(dataset, test_indices_flat)

print(f"Validation samples: {len(val_dataset)}, Test samples: {len(test_dataset)}")

# Step 2: Random geometric augmentations used to synthesise extra samples
augmentation_transform = transforms.Compose(
    [
        # Mirror left/right half of the time.
        transforms.RandomHorizontalFlip(p=0.5),
        # Rotate by an angle drawn from [-10, 10] degrees.
        transforms.RandomRotation(degrees=10),
        # Shift by up to 10% per axis and rescale between 0.9x and 1.1x.
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    ]
)

# Function to augment one sample that may be a PIL image or a tensor that
# has already been through base_transform
def augment_image_if_needed(image):
    """Return a randomly augmented, freshly normalised copy of ``image``.

    Args:
        image: A PIL image, or a tensor as produced by ``base_transform``.

    Returns:
        The tensor obtained by applying ``augmentation_transform`` and then
        ``base_transform`` to the image.

    A floating-point tensor input has already been normalised, so its values
    lie outside [0, 1]. Converting it to PIL as-is would corrupt the pixels,
    and running ``base_transform`` afterwards would normalise a second time.
    Every ``Normalize`` step of ``base_transform`` is therefore undone first.
    """
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu()
        if image.is_floating_point():
            # Invert x -> (x - mean) / std, last Normalize step first.
            for step in reversed(getattr(base_transform, "transforms", [])):
                if isinstance(step, transforms.Normalize):
                    mean = torch.as_tensor(step.mean, dtype=image.dtype).view(-1, 1, 1)
                    std = torch.as_tensor(step.std, dtype=image.dtype).view(-1, 1, 1)
                    image = image * std + mean
            # Guard against rounding pushing values slightly outside [0, 1].
            image = image.clamp(0.0, 1.0)
        image = ToPILImage()(image)
    image = augmentation_transform(image)
    image = base_transform(image)  # Resize / ToTensor / Normalize exactly once
    return image

# Function to bring a dataset to exactly ``target_length`` samples
def augment_dataset(dataset, target_length):
    """Resize ``dataset`` to ``target_length`` samples.

    Args:
        dataset: Map-style dataset yielding ``(image, label)`` pairs.
        target_length: Desired number of samples (non-negative).

    Returns:
        * If the dataset already has at least ``target_length`` samples, a
          ``Subset`` holding a random selection of ``target_length`` of them.
        * Otherwise a ``ConcatDataset`` of the original dataset followed by
          an in-memory list of augmented ``(image, label)`` pairs drawn
          (with replacement) from the original samples.

    Raises:
        ValueError: If ``target_length`` is negative, or if samples must be
            generated from an empty dataset.
    """
    if target_length < 0:
        raise ValueError(f"target_length must be non-negative, got {target_length}")

    current_length = len(dataset)
    num_samples_to_augment = target_length - current_length

    if num_samples_to_augment <= 0:
        # Enough data already: keep a random selection of the right size.
        indices = list(range(current_length))
        random.shuffle(indices)
        return Subset(dataset, indices[:target_length])

    if current_length == 0:
        # Without at least one source sample there is nothing to augment.
        raise ValueError("Cannot augment an empty dataset to a positive target length.")

    augmented_samples = []
    for _ in range(num_samples_to_augment):
        index = random.randint(0, current_length - 1)
        image, label = dataset[index]
        augmented_samples.append((augment_image_if_needed(image), label))

    # A plain list of (image, label) tuples supports len() and indexing, which
    # is all ConcatDataset needs from a member dataset.
    return ConcatDataset([dataset, augmented_samples])

# Step 3: One training Subset per class, later carved up between the users
num_classes = label_dim
distinct_class_datasets = [
    Subset(dataset, train_indices_per_class[class_idx])
    for class_idx in range(num_classes)
]

# Report how many training samples each class contributes.
for i, distinct_class_dataset in enumerate(distinct_class_datasets):
    print(f"Class {i} dataset size: {len(distinct_class_dataset)}")

# Randomly partition a dataset into two disjoint parts
def split_dataset(dataset, split_ratio):
    """Split ``dataset`` at random into two non-overlapping subsets.

    Args:
        dataset: Any object accepted by ``torch.utils.data.random_split``.
        split_ratio: Fraction of samples for the first part; the resulting
            count is rounded with ``np.round``.

    Returns:
        A ``(first_part, second_part)`` tuple of ``Subset`` objects whose
        sizes add up to ``len(dataset)``.
    """
    total = len(dataset)
    first_size = int(np.round(split_ratio * total))
    first_part, second_part = torch.utils.data.random_split(
        dataset, [first_size, total - first_size]
    )
    return first_part, second_part

# First cut: every class is divided at random into two equal halves.
split_ratio = 0.5
split_datasets = [
    split_dataset(distinct_class_dataset, split_ratio)
    for distinct_class_dataset in distinct_class_datasets
]
train_class_datasets1 = [first_half for first_half, _ in split_datasets]
train_class_datasets2 = [second_half for _, second_half in split_datasets]

for i, (first_half, second_half) in enumerate(split_datasets):
    print(f"Class {i}:")
    print(f"  Number of samples in train_class_datasets1: {len(first_half)}")
    print(f"  Number of samples in train_class_datasets2: {len(second_half)}")

# Second cut: each second half is divided again into 70% / 30% parts.
split_ratio = 0.7
split_datasets2 = [
    split_dataset(second_half, split_ratio)
    for second_half in train_class_datasets2
]
train_class_datasets2_part1 = [major_part for major_part, _ in split_datasets2]
train_class_datasets2_part2 = [minor_part for _, minor_part in split_datasets2]

for i, (major_part, minor_part) in enumerate(split_datasets2):
    print(f"Class {i}:")
    print(f"  Number of samples in train_class_datasets2_part1 (70%): {len(major_part)}")
    print(f"  Number of samples in train_class_datasets2_part2 (30%): {len(minor_part)}")

# Target length for all datasets
lengthiest_length = 500

# Bring every per-class piece to exactly ``lengthiest_length`` samples.
# Larger pieces are randomly downsampled, smaller ones are topped up with
# augmented copies (see augment_dataset). The comprehension and loop
# variables are deliberately not called ``dataset`` so that the module-level
# ``dataset`` (the transformed VehicleTypeDataset) is not rebound.
train_class_datasets1 = [augment_dataset(part, lengthiest_length) for part in train_class_datasets1]
train_class_datasets2_augmented = [augment_dataset(part, lengthiest_length) for part in train_class_datasets2]
train_class_datasets2_part1 = [augment_dataset(part, lengthiest_length) for part in train_class_datasets2_part1]
train_class_datasets2_part2 = [augment_dataset(part, lengthiest_length) for part in train_class_datasets2_part2]

# Print the sizes of the augmented datasets
print("Sizes of augmented datasets:")
for list_name, augmented_list in (
    ("train_class_datasets1", train_class_datasets1),
    ("train_class_datasets2", train_class_datasets2_augmented),
    ("train_class_datasets2_part1", train_class_datasets2_part1),
    ("train_class_datasets2_part2", train_class_datasets2_part2),
):
    for i, augmented_part in enumerate(augmented_list):
        print(f"  Length of augmented {list_name}[{i}]: {len(augmented_part)}")

# Step 4: Distribute datasets to users as per the new scheme
Num_users = 5
user_data = []
user_indices_sets = [[] for _ in range(Num_users)]

# Helper function to collect root-dataset indices from a dataset
# (handles nested Subset, ConcatDataset and in-memory sample lists)
def collect_indices(dataset):
    """Return the root-dataset indices of the original samples in ``dataset``.

    ``Subset.indices`` are relative to the subset's *parent*. The pieces
    handled here are subsets of subsets (class subset -> random_split ->
    optional downsampling), so each chain of nested ``Subset`` objects is
    followed upwards and the indices are translated at every level until a
    non-``Subset`` parent is reached.

    Args:
        dataset: A ``Subset``, a ``ConcatDataset`` of supported members, or a
            plain list of in-memory ``(image, label)`` samples.

    Returns:
        A list of integer indices. In-memory sample lists contribute nothing
        because augmented samples have no index in the root dataset.

    Raises:
        ValueError: If ``dataset`` is of any other type.

    NOTE(review): a ``Subset`` whose parent is a ``ConcatDataset`` is returned
    with indices relative to that ConcatDataset; this script never builds one.
    """
    if isinstance(dataset, Subset):
        indices = list(dataset.indices)
        parent = dataset.dataset
        while isinstance(parent, Subset):
            # Translate positions within ``parent`` into positions within
            # the dataset that ``parent`` itself was cut from.
            indices = [parent.indices[i] for i in indices]
            parent = parent.dataset
        return indices
    if isinstance(dataset, ConcatDataset):
        indices = []
        for sub_dataset in dataset.datasets:
            indices.extend(collect_indices(sub_dataset))
        return indices
    if isinstance(dataset, list):
        # List of augmented (image, label) tuples: no root index to report.
        return []
    raise ValueError(f"Unsupported dataset type: {type(dataset)}")

# Distribute data to users using the augmented train_class_datasets2
user_data.append(ConcatDataset([
    train_class_datasets1[0],
    train_class_datasets2_augmented[1],  # Use the augmented version
    train_class_datasets2_part2[2],
    train_class_datasets2_part2[3],
    train_class_datasets2_part2[4]
]))
user_data.append(ConcatDataset([
    train_class_datasets1[1],
    train_class_datasets2_part1[2]
]))
user_data.append(ConcatDataset([
    train_class_datasets1[2],
    train_class_datasets2_part1[3]
]))
user_data.append(ConcatDataset([
    train_class_datasets1[3],
    train_class_datasets2_part1[4]
]))
user_data.append(ConcatDataset([
    train_class_datasets1[4]
]))

# Collect root indices for overlap checking (only original samples, not augmented ones)
for user_idx, user_dataset in enumerate(user_data):
    user_indices_sets[user_idx].extend(collect_indices(user_dataset))
    print(f"User {user_idx + 1}: Total samples = {len(user_dataset)}")

# Step 4.1: Compute class distribution before and after augmentation for verification
samples_before_after = {
    user_idx: {class_idx: {'before': 0, 'after': 0} for class_idx in range(label_dim)}
    for user_idx in range(Num_users)
}

# Compute "before" counts: original samples only. Root indices are used so
# the label lookup is valid, and originals inside augmented ConcatDatasets
# are counted as well.
for user_idx in range(Num_users):
    for idx in user_indices_sets[user_idx]:
        label = dataset_no_transform.labels[idx]
        samples_before_after[user_idx][label]['before'] += 1

# Compute "after" counts (total samples, including augmented ones)
class_counts_per_user = []
for user_idx in range(Num_users):
    user_dataset = user_data[user_idx]
    class_counts = [0] * label_dim
    for idx in range(len(user_dataset)):
        _, label = user_dataset[idx]
        class_counts[label] += 1
    for class_idx, count in enumerate(class_counts):
        samples_before_after[user_idx][class_idx]['after'] = count
    class_counts_per_user.append(class_counts)
    print(f"User {user_idx + 1} Class Distribution (After Augmentation): {class_counts}")
    total_samples = len(user_dataset)
    class_percentages = [count / total_samples * 100 if total_samples > 0 else 0 for count in class_counts]
    print(f"User {user_idx + 1} Class Percentages (After Augmentation): {[f'{p:.2f}%' for p in class_percentages]}")

# Step 4.2: One bar chart per class showing how many original
# (pre-augmentation) samples of that class each user holds
print("\n=== Plotting Histograms for Each Class (Before Augmentation) ===")
user_positions = range(Num_users)
user_tick_labels = [f"User {i+1}" for i in user_positions]
for class_idx in range(label_dim):
    sample_counts = []
    for user_idx in user_positions:
        sample_counts.append(samples_before_after[user_idx][class_idx]['before'])

    plt.figure(figsize=(8, 6))
    plt.bar(user_positions, sample_counts, align='center', alpha=0.7)
    plt.xlabel('User Index')
    plt.ylabel('Number of Samples')
    plt.title(f"Class {class_idx} ({dataset_no_transform.class_names[class_idx]}) Sample Distribution Across Users (Before Augmentation)")
    plt.xticks(user_positions, user_tick_labels)
    plt.grid(True, axis='y')
    histogram_path = os.path.join(output_dir, f"class_{class_idx}_sample_distribution_histogram_before_augmentation.png")
    plt.savefig(histogram_path)
    plt.show()
    print(f"Saved histogram for Class {class_idx} to {histogram_path}")

# Tabulate, per user and class, the sample counts before and after augmentation
print("\n=== Samples Before and After Augmentation ===")
for user_idx in range(Num_users):
    print(f"\nUser {user_idx + 1}:")
    per_class_counts = samples_before_after[user_idx]
    for class_idx in range(label_dim):
        counts = per_class_counts[class_idx]
        before, after = counts['before'], counts['after']
        print(f"Class {class_idx} ({dataset_no_transform.class_names[class_idx]}): Before={before}, After={after}")

# Step 4.3: Report how many collected indices each pair of users shares
print("\n=== Verifying Data Distribution Across Users ===")
for i in range(Num_users):
    indices_i = set(user_indices_sets[i])
    for j in range(i + 1, Num_users):
        overlap = indices_i.intersection(set(user_indices_sets[j]))
        print(f"Overlap between User {i+1} and User {j+1}: {len(overlap)} samples")
        if overlap:
            print(f"Overlapping indices: {overlap}")

# Step 5: Visualize 5 images per class for each user in a 5x5 grid
# Display-only preprocessing: no normalisation, so pixel values stay in [0, 1].
display_transform = transforms.Compose([
    transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.LANCZOS),  # High-quality resizing
    transforms.RandomAdjustSharpness(sharpness_factor=2.0, p=1.0),  # Apply sharpening
    transforms.ToTensor()
])

# Third dataset instance over the same files, used only for plotting.
vis_dataset = VehicleTypeDataset(root_dir=dataset_path, transform=display_transform)

for user_idx in range(Num_users):
    print(f"\nCollecting images for User {user_idx + 1}...")
    user_indices = user_indices_sets[user_idx]

    # Collect indices for each class (up to 5 images per class)
    images_per_class = {class_idx: [] for class_idx in range(label_dim)}
    for idx in user_indices:
        label = dataset_no_transform.labels[idx]
        if len(images_per_class[label]) < 5:
            images_per_class[label].append(idx)

    for class_idx in range(label_dim):
        num_images = len(images_per_class[class_idx])
        print(f"User {user_idx + 1}, Class {class_idx} ({dataset_no_transform.class_names[class_idx]}): {num_images} images collected")
        if num_images < 5:
            print(f"Warning: User {user_idx + 1}, Class {class_idx} has only {num_images} images (less than 5).")

    # Plot images in a 5x5 grid (5 classes, 5 images each)
    fig, axes = plt.subplots(5, 5, figsize=(12, 12))  # 5x5 grid
    fig.suptitle(f"User {user_idx + 1}: 5 Images Per Class (5x5 Grid, Non-IID)", fontsize=16)

    for row in range(5):  # 5 rows for 5 classes
        class_idx = row  # Each row corresponds to a class
        class_name = dataset_no_transform.class_names[class_idx]
        for col in range(5):  # 5 columns for 5 images per class
            ax = axes[row, col]
            if col < len(images_per_class[class_idx]):
                vis_idx = images_per_class[class_idx][col]
                img, _ = vis_dataset[vis_idx]  # Load the image with display_transform
                img = img.permute(1, 2, 0).numpy()  # Convert to HWC for display
                img = np.clip(img, 0, 1)  # Clip to [0, 1]
                ax.imshow(img)
            else:
                ax.text(0.5, 0.5, 'No Image', ha='center', va='center', fontsize=8)

            # Hide ticks and frame individually: ax.axis('off') would also
            # hide the y-label that carries the class name.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if col == 0:  # Label every row with its class name
                ax.set_ylabel(class_name, rotation=90, labelpad=10, fontsize=10)

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    image_save_path = os.path.join(output_dir, f"user_{user_idx + 1}_images_5x5_grid_non_iid.png")
    plt.savefig(image_save_path)
    plt.show()  # Show the plot
    plt.close(fig)  # Release the figure so figures do not pile up across users
    print(f"Saved 5x5 visualization for User {user_idx + 1} to {image_save_path}")

# ResNet50 Model for RGB
class ResNet50Model(nn.Module):
    """ResNet-50 classifier with its final layer resized to ``num_classes``.

    Args:
        num_classes: Number of output classes.

    ``forward`` returns log-probabilities (``log_softmax`` over dim 1).
    """

    def __init__(self, num_classes):
        super().__init__()
        # Load pre-trained ResNet50. ``models`` is not imported by this file,
        # so the factory is reached through the ``torchvision`` package.
        # NOTE(review): recent torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; kept to stay compatible with older ones.
        self.resnet50 = torchvision.models.resnet50(pretrained=True)
        # All layers stay trainable (full fine-tuning). To train only the
        # new head, set requires_grad = False on self.resnet50.parameters()
        # before replacing ``fc`` below.
        # Modify the fully connected layer to match the number of classes
        in_features = self.resnet50.fc.in_features
        self.resnet50.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``."""
        return F.log_softmax(self.resnet50(x), dim=1)

# Federated Learning Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): the previously referenced ``CNN`` class is not defined in
# this file; ResNet50Model is the only model defined here, so it is used for
# the global model and for every client. Confirm this is the intended model.
global_model = ResNet50Model(num_classes=label_dim).to(device)
local_models = [ResNet50Model(num_classes=label_dim).to(device) for _ in range(Num_users)]
batch_size = 128
num_local_epochs = 3
num_rounds = 100

# Per-round history, filled by the training loop and saved/plotted afterwards.
communication_rounds = []
global_accuracies = []
global_losses = []

# Objects that do not change between rounds are built once up front. Binding
# ``criterion`` here also removes the global evaluation's reliance on a name
# that used to be assigned only inside the per-user loop.
criterion = nn.CrossEntropyLoss()
local_tr_data_loaders = [
    DataLoader(user_data[user_idx], batch_size=batch_size, shuffle=True)
    for user_idx in range(Num_users)
]
local_val_data_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# ``round_idx`` (not ``round``) so the builtin ``round`` is not shadowed.
for round_idx in range(num_rounds):
    round_start_time = time.time()
    print(f"\nRound {round_idx + 1}")

    # ---- Local training: every client starts from the current global model.
    for user_idx in range(Num_users):
        local_model = local_models[user_idx]
        local_model.load_state_dict(global_model.state_dict())
        local_tr_data_loader = local_tr_data_loaders[user_idx]
        optimizer = optim.SGD(local_model.parameters(), lr=0.001, momentum=0.9)
        # NOTE(review): optimizer and scheduler are re-created every round, so
        # StepLR(step_size=5) only decays the learning rate if
        # num_local_epochs is at least 5.
        scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

        for epoch in range(num_local_epochs):
            epoch_start_time = time.time()

            local_model.train()
            running_loss = 0.0
            correct_train = 0
            total_train = 0
            for data, target in local_tr_data_loader:
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = local_model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                _, predicted = output.max(1)
                total_train += target.size(0)
                correct_train += predicted.eq(target).sum().item()

            train_loss = running_loss / len(local_tr_data_loader) if len(local_tr_data_loader) > 0 else float('inf')
            train_accuracy = 100 * correct_train / total_train if total_train > 0 else 0.0

            # Per-epoch validation on the shared validation set.
            local_model.eval()
            val_loss = 0.0
            correct_val = 0
            total_val = 0
            with torch.no_grad():
                for data, target in local_val_data_loader:
                    data, target = data.to(device), target.to(device)
                    output = local_model(data)
                    val_loss += criterion(output, target).item()
                    _, predicted = output.max(1)
                    total_val += target.size(0)
                    correct_val += predicted.eq(target).sum().item()

            val_loss /= len(local_val_data_loader) if len(local_val_data_loader) > 0 else 1
            val_accuracy = 100 * correct_val / total_val if total_val > 0 else 0.0

            epoch_duration = time.time() - epoch_start_time
            print(f"User {user_idx + 1}, Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
                  f"Train Acc: {train_accuracy:.2f}%, Val Acc: {val_accuracy:.2f}%, "
                  f"Time Taken: {epoch_duration:.2f} seconds")
            scheduler.step()

    # ---- Federated Averaging: element-wise mean of all client state dicts.
    # NOTE(review): clients are weighted equally regardless of how many
    # samples each one holds.
    local_state_dicts = [local_model.state_dict() for local_model in local_models]
    averaged_state_dict = {}
    for key, global_value in global_model.state_dict().items():
        stacked = torch.stack([state[key].float() for state in local_state_dicts])
        # Average in float, then restore the entry's own dtype so integer
        # buffers are stored as integers again.
        averaged_state_dict[key] = stacked.mean(dim=0).to(global_value.dtype)
    global_model.load_state_dict(averaged_state_dict)

    # ---- Global Evaluation on the held-out test set.
    global_model.eval()
    global_loss = 0.0
    correct_test = 0
    with torch.no_grad():
        for data, labels in test_loader:
            data, labels = data.to(device), labels.to(device)
            outputs = global_model(data)
            global_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            correct_test += (predicted == labels).sum().item()

    global_accuracy = 100 * correct_test / len(test_dataset) if len(test_dataset) > 0 else 0.0
    global_loss /= len(test_loader) if len(test_loader) > 0 else 1

    round_duration = time.time() - round_start_time
    print(f"Round {round_idx + 1}: Global Test Accuracy: {global_accuracy:.2f}%, Global Loss: {global_loss:.4f}, "
          f"Time Taken: {round_duration:.2f} seconds")

    communication_rounds.append(round_idx + 1)
    global_accuracies.append(global_accuracy)
    global_losses.append(global_loss)

# Persist the final global weights (state dict only) in the output folder.
model_save_path = os.path.join(output_dir, "cnn2_128_500_rgb_sgd_global_model_non_iid_ANO_DIST.pth")
torch.save(global_model.state_dict(), model_save_path)
print(f"Global model saved to {model_save_path}")

# Output locations for the per-round metrics table and the two curve plots.
csv_file_path = os.path.join(output_dir, "cnn2_128_500_vehicle_rgb_sgd_non_iid_ANO_DIST.csv")
accuracy_plot_file_path = os.path.join(output_dir, "cnn2_128_500_vehicle_rgb_sgd_accuracy_non_iid_ANO_DIST.png")
loss_plot_file_path = os.path.join(output_dir, "cnn2_128_500_vehicle_rgb_sgd_loss_non_iid_ANO_DIST.png")

# One CSV row per communication round: round number, test accuracy, test loss.
with open(csv_file_path, 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(['Communication Rounds', 'Global Accuracies', 'Global Losses'])
    for metrics_row in zip(communication_rounds, global_accuracies, global_losses):
        writer.writerow(metrics_row)

# Display and save the global accuracy and loss curves. Both figures share
# the same layout, so they are drawn from a small table of settings:
# (kind, values, marker, linestyle, color, legend label, y-label, title, path)
curve_specs = (
    (
        'accuracy', global_accuracies, 's', '-', 'b',
        'CNN2 128 500(AUG) RGB SGD Global Accuracies (Non-IID)',
        'Global Accuracies',
        'Global Accuracies vs. Communication Rounds (CNN2 128 RGB SGD, Non-IID)',
        accuracy_plot_file_path,
    ),
    (
        'loss', global_losses, 'o', '--', 'r',
        'CNN2 128 500(AUG) RGB SGD Global Losses (Non-IID)',
        'Global Losses',
        'Global Losses vs. Communication Rounds (CNN2 128 RGB SGD, Non-IID)',
        loss_plot_file_path,
    ),
)
for kind, values, marker, linestyle, color, legend_label, y_label, title, plot_path in curve_specs:
    plt.figure(figsize=(10, 5))
    plt.plot(communication_rounds, values, marker=marker, linestyle=linestyle, color=color, label=legend_label)
    plt.xlabel('Communication Rounds')
    plt.ylabel(y_label)
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.savefig(plot_path)
    plt.show()
    print(f"Saved {kind} plot to {plot_path}")

# Class distribution of User 1 after augmentation: one bin per class index,
# with the per-class sample counts supplied as bin weights.
class_positions = range(label_dim)
plt.figure(figsize=(8, 6))
plt.hist(class_positions, weights=class_counts_per_user[0], bins=range(label_dim + 1), align='left', rwidth=0.8)
plt.xlabel('Class Index')
plt.ylabel('Number of Samples')
plt.title("User 1 Class Distribution Histogram (Non-IID)")
plt.xticks(class_positions, dataset_no_transform.class_names, rotation=45)
plt.grid(True)
histogram_path = os.path.join(output_dir, "user_1_class_distribution_histogram.png")
plt.savefig(histogram_path)
plt.show()
print(f"Saved User 1 class distribution histogram to {histogram_path}")