In [4]:
from collections import defaultdict
import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

import os

# Dataset root, laid out for torchvision ImageFolder (one sub-folder per class).
# Set the MEDICAL_MNIST_DIR environment variable to point elsewhere instead of
# editing this absolute, machine-specific path.
dataset_path = os.environ.get(
    "MEDICAL_MNIST_DIR", "C:\\Users\\USER\\Downloads\\Medical Minist Dataset"
)

# Per-channel statistics the ImageNet-pretrained torchvision weights were trained
# with. Images fed to the pretrained ResNet-18 feature extractor should be
# normalised the same way, otherwise its activations are computed on inputs
# shifted/scaled differently from what the weights expect.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: resize, light random augmentation, then tensor conversion + normalisation.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation: deterministic resize + the same normalisation, no augmentation.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Load dataset. No transform is attached here: each per-client subset applies
# train_transform or test_transform itself, so one ImageFolder backs both splits.
full_dataset = datasets.ImageFolder(root=dataset_path)

# Seed NumPy (used for the shuffles below) and torch's global RNG (which
# DataLoader(shuffle=True) draws from) so Restart & Run All reproduces the same
# client partition and splits.
SEED = 42
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

# Group sample positions by class label.
class_indices = defaultdict(list)
for idx, (_, label) in enumerate(full_dataset.samples):
    class_indices[label].append(idx)

num_clients = 20
train_ratio = 0.8  # 80% training, 20% testing

# Stratified partition: shuffle each class, cut it into num_clients near-equal
# chunks and give chunk k to client k, so every client gets ~the same class mix.
client_indices = [[] for _ in range(num_clients)]
for label, indices in class_indices.items():
    rng.shuffle(indices)
    splits = np.array_split(indices, num_clients)
    for client_id, split in enumerate(splits):
        client_indices[client_id].extend(split.tolist())

# array_split hands some clients one extra sample per class. Trim every client to
# the smallest client size using an exact integer, rather than re-adding the float
# train/test shares, which can round down and drop one sample too many.
min_client_size = min(len(indices) for indices in client_indices)
min_train_size = min_client_size * train_ratio        # float share, same value as before
min_test_size = min_client_size * (1 - train_ratio)   # float share, same value as before

for i in range(num_clients):
    rng.shuffle(client_indices[i])
    client_indices[i] = client_indices[i][:min_client_size]

assert all(len(ci) == min_client_size for ci in client_indices), "clients differ in size"

# Index-restricted view over a base dataset with its own optional transform.
class CustomSubset(Dataset):
    """Expose ``dataset[indices[k]]`` as item ``k``, transforming only the image.

    One untransformed base dataset can back several views (e.g. train and test),
    each with a different transform. Labels are returned unchanged.
    """

    def __init__(self, dataset, indices, transform=None):
        self.dataset, self.indices, self.transform = dataset, indices, transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        sample, target = self.dataset[self.indices[idx]]
        return (self.transform(sample) if self.transform else sample), target

# Per-client train/test split. Each client_indices[i] was shuffled before being
# trimmed, so taking the first `train_ratio` share for training and the rest for
# testing is a random (not class-stratified) split.
batch_size = 32

train_loaders, test_loaders = [], []
for i, indices in enumerate(client_indices):
    split_idx = int(len(indices) * train_ratio)
    train_indices, test_indices = indices[:split_idx], indices[split_idx:]

    # Same base ImageFolder for both sides; only the training view is augmented.
    train_dataset = CustomSubset(full_dataset, train_indices, transform=train_transform)
    test_dataset = CustomSubset(full_dataset, test_indices, transform=test_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    train_loaders.append(train_loader)
    test_loaders.append(test_loader)

    print(f"✅ Client {i+1} Data Loaded: Train={len(train_dataset)}, Test={len(test_dataset)}")

# Subset sizes per client, in client order.
train_sizes = [len(loader.dataset) for loader in train_loaders]
test_sizes = [len(loader.dataset) for loader in test_loaders]

✅ Client 1 Data Loaded: Train=2357, Test=590
✅ Client 2 Data Loaded: Train=2357, Test=590
✅ Client 3 Data Loaded: Train=2357, Test=590
✅ Client 4 Data Loaded: Train=2357, Test=590
✅ Client 5 Data Loaded: Train=2357, Test=590
✅ Client 6 Data Loaded: Train=2357, Test=590
✅ Client 7 Data Loaded: Train=2357, Test=590
✅ Client 8 Data Loaded: Train=2357, Test=590
✅ Client 9 Data Loaded: Train=2357, Test=590
✅ Client 10 Data Loaded: Train=2357, Test=590
✅ Client 11 Data Loaded: Train=2357, Test=590
✅ Client 12 Data Loaded: Train=2357, Test=590
✅ Client 13 Data Loaded: Train=2357, Test=590
✅ Client 14 Data Loaded: Train=2357, Test=590
✅ Client 15 Data Loaded: Train=2357, Test=590
✅ Client 16 Data Loaded: Train=2357, Test=590
✅ Client 17 Data Loaded: Train=2357, Test=590
✅ Client 18 Data Loaded: Train=2357, Test=590
✅ Client 19 Data Loaded: Train=2357, Test=590
✅ Client 20 Data Loaded: Train=2357, Test=590


In [6]:
import torch
import torchvision.models as models
import numpy as np
from tqdm import tqdm

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ImageNet-pretrained ResNet-18 with its last two child modules (global avgpool
# and fc) dropped, so it returns spatial feature maps; pooling is done during
# extraction. eval() puts layers such as BatchNorm into inference behaviour.
resnet_model = torch.nn.Sequential(
    *list(models.resnet18(weights=models.ResNet18_Weights.DEFAULT).children())[:-2]
)
resnet_model = resnet_model.to(device).eval()

# Feature extraction function
def extract_features(data_loader, dataset_type, client_id, model, device=None):
    """Run ``model`` over every batch of ``data_loader``; pool, save and return features.

    Args:
        data_loader: iterable yielding ``(images, labels)`` tensor batches.
        dataset_type: tag used in the progress bar, log line and file names
            (e.g. ``"train"`` / ``"test"``).
        client_id: client number used in the progress bar, log line and file names.
        model: feature extractor. It is used as-is; this function does not switch
            it to eval mode, so the caller should do that.
        device: device the images are moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model), so inputs
            always land where the model lives.

    Returns:
        ``(features, labels)`` NumPy arrays whose rows are aligned and in loader order.

    Side effects:
        Writes ``client_{client_id}_{dataset_type}_features.npy`` and
        ``client_{client_id}_{dataset_type}_labels.npy`` to the working directory.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    feature_batches, label_batches = [], []
    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc=f"Extracting {dataset_type} Features for Client {client_id}"):
            features = model(images.to(device))
            # Global average pooling over every spatial dim (dims 2, 3 for a
            # (N, C, H, W) feature map); already-pooled (N, C) outputs pass through.
            if features.dim() > 2:
                features = features.mean(dim=tuple(range(2, features.dim())))
            feature_batches.append(features.cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    if not feature_batches:
        # Fail before writing anything rather than with np.vstack's opaque error.
        raise ValueError(f"No batches to extract for client {client_id} ({dataset_type}).")

    features_array = np.vstack(feature_batches)
    labels_array = np.concatenate(label_batches)

    # Save extracted features and labels
    np.save(f"client_{client_id}_{dataset_type}_features.npy", features_array)
    np.save(f"client_{client_id}_{dataset_type}_labels.npy", labels_array)

    print(f"✅ Client {client_id} {dataset_type} Features Extracted: {features_array.shape}")

    return features_array, labels_array

# Run the shared frozen backbone over every client's train and test loaders so both
# splits live in the same feature space (extract_features also writes .npy files).
# NOTE(review): train_loader applies train_transform (random flip/rotation/jitter)
# and shuffles, so each cached training feature reflects one random augmentation of
# its image; labels stay row-aligned because they are collected batch by batch.
client_train_features, client_train_labels = [], []
client_test_features, client_test_labels = [], []

for i, (train_loader, test_loader) in enumerate(zip(train_loaders, test_loaders)):
    train_features, train_labels = extract_features(train_loader, "train", i+1, resnet_model)
    test_features, test_labels = extract_features(test_loader, "test", i+1, resnet_model)

    client_train_features.append(train_features)
    client_train_labels.append(train_labels)
    client_test_features.append(test_features)
    client_test_labels.append(test_labels)

print("\n✅ Feature Extraction Complete!")

# Shape sanity check per client: (num_samples, feature_dim) for train and test.
for i, features in enumerate(client_train_features):
    print(f"Client {i+1}: Train Shape = {features.shape}, Test Shape = {client_test_features[i].shape}")

Extracting train Features for Client 1: 100%|██████████| 74/74 [04:01<00:00,  3.27s/it]


✅ Client 1 train Features Extracted: (2357, 512)


Extracting test Features for Client 1: 100%|██████████| 19/19 [01:02<00:00,  3.31s/it]


✅ Client 1 test Features Extracted: (590, 512)


Extracting train Features for Client 2: 100%|██████████| 74/74 [03:58<00:00,  3.23s/it]


✅ Client 2 train Features Extracted: (2357, 512)


Extracting test Features for Client 2: 100%|██████████| 19/19 [00:54<00:00,  2.87s/it]


✅ Client 2 test Features Extracted: (590, 512)


Extracting train Features for Client 3: 100%|██████████| 74/74 [03:47<00:00,  3.07s/it]


✅ Client 3 train Features Extracted: (2357, 512)


Extracting test Features for Client 3: 100%|██████████| 19/19 [00:53<00:00,  2.82s/it]


✅ Client 3 test Features Extracted: (590, 512)


Extracting train Features for Client 4: 100%|██████████| 74/74 [03:47<00:00,  3.08s/it]


✅ Client 4 train Features Extracted: (2357, 512)


Extracting test Features for Client 4: 100%|██████████| 19/19 [00:53<00:00,  2.82s/it]


✅ Client 4 test Features Extracted: (590, 512)


Extracting train Features for Client 5: 100%|██████████| 74/74 [03:41<00:00,  3.00s/it]


✅ Client 5 train Features Extracted: (2357, 512)


Extracting test Features for Client 5: 100%|██████████| 19/19 [00:52<00:00,  2.78s/it]


✅ Client 5 test Features Extracted: (590, 512)


Extracting train Features for Client 6: 100%|██████████| 74/74 [03:45<00:00,  3.04s/it]


✅ Client 6 train Features Extracted: (2357, 512)


Extracting test Features for Client 6: 100%|██████████| 19/19 [00:55<00:00,  2.94s/it]


✅ Client 6 test Features Extracted: (590, 512)


Extracting train Features for Client 7: 100%|██████████| 74/74 [03:38<00:00,  2.96s/it]


✅ Client 7 train Features Extracted: (2357, 512)


Extracting test Features for Client 7: 100%|██████████| 19/19 [00:53<00:00,  2.80s/it]


✅ Client 7 test Features Extracted: (590, 512)


Extracting train Features for Client 8: 100%|██████████| 74/74 [03:39<00:00,  2.97s/it]


✅ Client 8 train Features Extracted: (2357, 512)


Extracting test Features for Client 8: 100%|██████████| 19/19 [00:50<00:00,  2.67s/it]


✅ Client 8 test Features Extracted: (590, 512)


Extracting train Features for Client 9: 100%|██████████| 74/74 [03:41<00:00,  2.99s/it]


✅ Client 9 train Features Extracted: (2357, 512)


Extracting test Features for Client 9: 100%|██████████| 19/19 [00:54<00:00,  2.85s/it]


✅ Client 9 test Features Extracted: (590, 512)


Extracting train Features for Client 10: 100%|██████████| 74/74 [03:42<00:00,  3.01s/it]


✅ Client 10 train Features Extracted: (2357, 512)


Extracting test Features for Client 10: 100%|██████████| 19/19 [00:52<00:00,  2.78s/it]


✅ Client 10 test Features Extracted: (590, 512)


Extracting train Features for Client 11: 100%|██████████| 74/74 [03:40<00:00,  2.98s/it]


✅ Client 11 train Features Extracted: (2357, 512)


Extracting test Features for Client 11: 100%|██████████| 19/19 [00:52<00:00,  2.77s/it]


✅ Client 11 test Features Extracted: (590, 512)


Extracting train Features for Client 12: 100%|██████████| 74/74 [03:41<00:00,  2.99s/it]


✅ Client 12 train Features Extracted: (2357, 512)


Extracting test Features for Client 12: 100%|██████████| 19/19 [00:50<00:00,  2.68s/it]


✅ Client 12 test Features Extracted: (590, 512)


Extracting train Features for Client 13: 100%|██████████| 74/74 [03:38<00:00,  2.96s/it]


✅ Client 13 train Features Extracted: (2357, 512)


Extracting test Features for Client 13: 100%|██████████| 19/19 [00:53<00:00,  2.80s/it]


✅ Client 13 test Features Extracted: (590, 512)


Extracting train Features for Client 14: 100%|██████████| 74/74 [03:37<00:00,  2.94s/it]


✅ Client 14 train Features Extracted: (2357, 512)


Extracting test Features for Client 14: 100%|██████████| 19/19 [00:52<00:00,  2.76s/it]


✅ Client 14 test Features Extracted: (590, 512)


Extracting train Features for Client 15: 100%|██████████| 74/74 [03:38<00:00,  2.95s/it]


✅ Client 15 train Features Extracted: (2357, 512)


Extracting test Features for Client 15: 100%|██████████| 19/19 [00:51<00:00,  2.74s/it]


✅ Client 15 test Features Extracted: (590, 512)


Extracting train Features for Client 16: 100%|██████████| 74/74 [03:37<00:00,  2.95s/it]


✅ Client 16 train Features Extracted: (2357, 512)


Extracting test Features for Client 16: 100%|██████████| 19/19 [00:51<00:00,  2.71s/it]


✅ Client 16 test Features Extracted: (590, 512)


Extracting train Features for Client 17: 100%|██████████| 74/74 [03:42<00:00,  3.01s/it]


✅ Client 17 train Features Extracted: (2357, 512)


Extracting test Features for Client 17: 100%|██████████| 19/19 [00:52<00:00,  2.76s/it]


✅ Client 17 test Features Extracted: (590, 512)


Extracting train Features for Client 18: 100%|██████████| 74/74 [03:45<00:00,  3.04s/it]


✅ Client 18 train Features Extracted: (2357, 512)


Extracting test Features for Client 18: 100%|██████████| 19/19 [00:51<00:00,  2.72s/it]


✅ Client 18 test Features Extracted: (590, 512)


Extracting train Features for Client 19: 100%|██████████| 74/74 [03:43<00:00,  3.02s/it]


✅ Client 19 train Features Extracted: (2357, 512)


Extracting test Features for Client 19: 100%|██████████| 19/19 [00:52<00:00,  2.78s/it]


✅ Client 19 test Features Extracted: (590, 512)


Extracting train Features for Client 20: 100%|██████████| 74/74 [03:39<00:00,  2.97s/it]


✅ Client 20 train Features Extracted: (2357, 512)


Extracting test Features for Client 20: 100%|██████████| 19/19 [00:52<00:00,  2.76s/it]

✅ Client 20 test Features Extracted: (590, 512)

✅ Feature Extraction Complete!
Client 1: Train Shape = (2357, 512), Test Shape = (590, 512)
Client 2: Train Shape = (2357, 512), Test Shape = (590, 512)
Client 3: Train Shape = (2357, 512), Test Shape = (590, 512)
Client 4: Train Shape = (2357, 512), Test Shape = (590, 512)
Client 5: Train Shape = (2357, 512), Test Shape = (590, 512)
Client 6: Train Shape = (2357, 512), Test Shape = (590, 512)
Client 7: Train Shape = (2357, 512), Test Shape = (590, 512)
Client 8: Train Shape = (2357, 512), Test Shape = (590, 512)
Client 9: Train Shape = (2357, 512), Test Shape = (590, 512)
Client 10: Train Shape = (2357, 512), Test Shape = (590, 512)
Client 11: Train Shape = (2357, 512), Test Shape = (590, 512)
Client 12: Train Shape = (2357, 512), Test Shape = (590, 512)
Client 13: Train Shape = (2357, 512), Test Shape = (590, 512)
Client 14: Train Shape = (2357, 512), Test Shape = (590, 512)
Client 15: Train Shape = (2357, 512), Test Shape = (590, 512)




In [7]:
from collections import Counter

def print_class_distribution(dataset, client_id):
    """Print how many samples of each class ``dataset`` contains.

    Args:
        dataset: subset view exposing ``.indices`` (positions into its base dataset)
            and ``.dataset.samples`` (list of ``(path, label)`` pairs), e.g. a
            CustomSubset over an ImageFolder.
        client_id: client number shown in the printed line.

    Prints a dict mapping class label -> count, ordered by class label.
    """
    samples = dataset.dataset.samples
    # Look up only this subset's labels instead of rebuilding the label list of the
    # whole base dataset on every call (it is shared by every client's subsets).
    label_counts = Counter(samples[idx][1] for idx in dataset.indices)
    print(f"📊 Client {client_id} Class Distribution: {dict(sorted(label_counts.items()))}")

# Per-client class counts: every training subset first, then every testing subset.
# Each loader's .dataset is the CustomSubset that print_class_distribution inspects.
for i, loader in enumerate(train_loaders):
    print(f"🔹 Training Data for Client {i+1}:")
    print_class_distribution(loader.dataset, i+1)

for i, loader in enumerate(test_loaders):
    print(f"🔹 Testing Data for Client {i+1}:")
    print_class_distribution(loader.dataset, i+1)

🔹 Training Data for Client 1:
📊 Client 1 Class Distribution: {1: 342, 3: 397, 0: 412, 4: 410, 2: 398, 5: 398}
🔹 Training Data for Client 2:
📊 Client 2 Class Distribution: {4: 406, 3: 407, 2: 404, 5: 394, 0: 407, 1: 339}
🔹 Training Data for Client 3:
📊 Client 3 Class Distribution: {4: 383, 5: 395, 2: 416, 0: 410, 1: 369, 3: 384}
🔹 Training Data for Client 4:
📊 Client 4 Class Distribution: {1: 341, 4: 402, 2: 418, 5: 407, 0: 398, 3: 391}
🔹 Training Data for Client 5:
📊 Client 5 Class Distribution: {0: 399, 5: 406, 1: 355, 2: 397, 3: 393, 4: 407}
🔹 Training Data for Client 6:
📊 Client 6 Class Distribution: {3: 393, 5: 407, 2: 403, 0: 405, 4: 401, 1: 348}
🔹 Training Data for Client 7:
📊 Client 7 Class Distribution: {5: 400, 0: 401, 3: 401, 1: 362, 4: 403, 2: 390}
🔹 Training Data for Client 8:
📊 Client 8 Class Distribution: {0: 394, 3: 408, 4: 403, 2: 383, 1: 357, 5: 412}
🔹 Training Data for Client 9:
📊 Client 9 Class Distribution: {4: 406, 1: 359, 2: 399, 3: 395, 0: 394, 5: 404}
🔹 Training