In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset

# Define transformations for the data
transform = transforms.Compose([
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 on every RGB channel
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Download and load the training dataset
trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)

# Split the dataset into train and validation sets
num_train = int(len(trainset) * 0.8)  # 80% of data for training
num_val = len(trainset) - num_train   # 20% of data for validation
# A seeded generator keeps the train/validation partition identical across
# kernel restarts; without it every run produces a different split.
# Note: `trainset` is rebound from the full dataset to the 80% training subset.
trainset, valset = random_split(
    trainset,
    [num_train, num_val],
    generator=torch.Generator().manual_seed(42),
)

# Define function for IID sharding
def iid_shard(dataset, num_clients):
    """Split ``dataset`` into ``num_clients`` contiguous, near-equal ``Subset`` shards.

    The split is purely positional, so the shards are IID only when
    ``dataset`` is already in random order (for example, when it is the
    output of ``random_split``).

    Args:
        dataset: Any sized, indexable dataset.
        num_clients: Number of shards to create; must satisfy
            ``1 <= num_clients <= len(dataset)``.

    Returns:
        A list of ``num_clients`` ``Subset`` objects that together cover every
        sample exactly once. When the length is not evenly divisible, the
        first ``len(dataset) % num_clients`` shards hold one extra sample
        instead of the remainder being dropped.

    Raises:
        ValueError: If ``num_clients`` is outside ``1..len(dataset)``.
    """
    total = len(dataset)
    if not 1 <= num_clients <= total:
        raise ValueError(
            f"num_clients must be between 1 and len(dataset)={total}, got {num_clients}"
        )

    base_size, remainder = divmod(total, num_clients)
    shards = []
    start = 0
    for client in range(num_clients):
        # The first `remainder` clients absorb the leftover samples.
        stop = start + base_size + (1 if client < remainder else 0)
        shards.append(Subset(dataset, range(start, stop)))
        start = stop
    return shards

# Number of simulated clients; each one gets its own shard
num_clients = 100

# Cut the training data into one positional shard per client
shards = iid_shard(dataset=trainset, num_clients=num_clients)

# One shuffling DataLoader per client shard
trainloaders = [
    DataLoader(client_shard, batch_size=4, shuffle=True, num_workers=2)
    for client_shard in shards
]

# CIFAR-100 test split, read in a fixed (unshuffled) order
testset = datasets.CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:10<00:00, 15666511.06it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, SubsetRandomSampler
import numpy as np

# Define transformations for the data
transform = transforms.Compose([
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 on every RGB channel
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Download and load the training dataset.
# Note: this binds `trainset` to the full training dataset again.
trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)

# Shuffle the sample indices with a fixed seed so the train/validation
# partition is the same on every run (the unseeded global shuffle was not)
indices = np.random.default_rng(42).permutation(len(trainset))

# Split the dataset into train and validation sets
num_train = int(len(trainset) * 0.8)  # 80% of data for training
num_val = len(trainset) - num_train   # 20% of data for validation

# Use SubsetRandomSampler to split the dataset: both samplers draw from the
# same `trainset`, but over disjoint slices of the shuffled indices
train_sampler = SubsetRandomSampler(indices[:num_train])
val_sampler = SubsetRandomSampler(indices[num_train:])

# DataLoader for the training and validation sets
trainloader = DataLoader(trainset, batch_size=4, sampler=train_sampler, num_workers=2)
valloader = DataLoader(trainset, batch_size=4, sampler=val_sampler, num_workers=2)

# Define function for sharding with Nc
def shard_with_Nc(dataset, num_clients, Nc):
    """Build one DataLoader per client, each limited to about ``Nc`` classes.

    Label-sorted (non-IID) sharding: sample indices are shuffled, stably
    sorted by label, cut into ``num_clients * Nc`` contiguous shards, and
    every client receives ``Nc`` randomly chosen shards. Each sample is
    assigned to exactly one client.

    The previous implementation never looked at the labels: it matched
    values inside a shuffled bookkeeping array, so every client ended up
    sampling from the whole dataset with many duplicated indices.

    Args:
        dataset: Map-style dataset. Labels are read from ``dataset.targets``
            when that attribute exists; otherwise they are collected by
            iterating ``(sample, label)`` pairs, which is slow.
        num_clients: Number of client loaders to create (>= 1).
        Nc: Number of shards, and hence roughly classes, per client (>= 1).

    Returns:
        A list of ``num_clients`` DataLoaders over ``dataset``, each using a
        ``SubsetRandomSampler`` restricted to that client's indices. A client
        sees at most ``Nc`` distinct classes when every class size is a
        multiple of the shard size; otherwise a shard may straddle a class
        boundary and add one more class.

    Raises:
        ValueError: If ``num_clients`` or ``Nc`` is not positive, or there
            are fewer samples than ``num_clients * Nc`` shards.
    """
    if num_clients < 1 or Nc < 1:
        raise ValueError("num_clients and Nc must both be positive")

    labels = getattr(dataset, "targets", None)
    if labels is None:
        # Fallback for datasets without a `targets` attribute.
        labels = [label for _, label in dataset]
    labels = np.asarray(labels)

    total_shards = num_clients * Nc
    if total_shards > len(labels):
        raise ValueError(
            f"cannot cut {len(labels)} samples into {total_shards} non-empty shards"
        )

    # Shuffle first, then stable-sort by label: samples stay in random order
    # inside each class while classes become contiguous.
    shuffled = np.random.permutation(len(labels))
    by_class = shuffled[np.argsort(labels[shuffled], kind="stable")]

    # Contiguous slices of the label-sorted order.
    shard_list = np.array_split(by_class, total_shards)

    # Random assignment of shards to clients, Nc shards each.
    shard_order = np.random.permutation(total_shards)

    client_loaders = []
    for client in range(num_clients):
        own_shards = shard_order[client * Nc:(client + 1) * Nc]
        indices = np.concatenate([shard_list[s] for s in own_shards]).tolist()

        # Create SubsetRandomSampler for client
        sampler = SubsetRandomSampler(indices)
        loader = DataLoader(dataset, batch_size=4, sampler=sampler, num_workers=2)
        client_loaders.append(loader)

    return client_loaders

# Client count and Nc (number of classes per client)
num_clients, Nc = 100, 5

# Build the per-client loaders with Nc-based sharding.
# NOTE(review): `trainset` here is the full training dataset, so the client
# loaders also cover the indices held out by `val_sampler` - confirm intended.
client_loaders = shard_with_Nc(dataset=trainset, num_clients=num_clients, Nc=Nc)

# CIFAR-100 test split, read in a fixed (unshuffled) order
testset = datasets.CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

Files already downloaded and verified
Files already downloaded and verified


In [11]:
def inspect_dataloader(dataloader):
    """Print a summary of the first batch yielded by ``dataloader``.

    Pulls a single ``(images, labels)`` batch and prints its length, the
    shape of the image container, and one line per label.

    Args:
        dataloader: Iterable whose first item unpacks into ``(images, labels)``.
    """
    first_images, first_labels = next(iter(dataloader))

    # Batch-level summary
    print(f"Batch Size: {len(first_images)}")
    print(f"Data Shape: {first_images.shape}")

    # One line per label, numbered from 1
    print("Labels:")
    for position, value in enumerate(first_labels, start=1):
        print(f"  Sample {position}: {value}")


# Print the first batch of every IID client loader
for index, client in enumerate(trainloaders):
    print(f'client: {index}')
    inspect_dataloader(client)



client: 0
Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 2
  Sample 2: 65
  Sample 3: 10
  Sample 4: 56
client: 1
Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 51
  Sample 2: 63
  Sample 3: 37
  Sample 4: 80
client: 2
Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 91
  Sample 2: 53
  Sample 3: 66
  Sample 4: 38
client: 3
Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 64
  Sample 2: 66
  Sample 3: 81
  Sample 4: 0
client: 4
Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 97
  Sample 2: 78
  Sample 3: 33
  Sample 4: 88
client: 5
Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 10
  Sample 2: 6
  Sample 3: 81
  Sample 4: 66
client: 6
Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 4
  Sample 2: 39
  Sample 3: 65
  Sample 4: 94
client: 7
Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 81
 

In [14]:
# Print the first batch of every Nc-sharded client loader
for index, client in enumerate(client_loaders):
    print(f'client: {index}')
    inspect_dataloader(client)

client: 0


  self.pid = os.fork()


Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 53
  Sample 2: 33
  Sample 3: 19
  Sample 4: 51
client: 1
Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 76
  Sample 2: 20
  Sample 3: 19
  Sample 4: 46
client: 2
Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 15
  Sample 2: 67
  Sample 3: 39
  Sample 4: 83
client: 3
Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 64
  Sample 2: 97
  Sample 3: 39
  Sample 4: 59
client: 4
Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 15
  Sample 2: 63
  Sample 3: 23
  Sample 4: 77
client: 5
Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 26
  Sample 2: 45
  Sample 3: 66
  Sample 4: 81
client: 6
Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 32
  Sample 2: 46
  Sample 3: 20
  Sample 4: 87
client: 7
Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 9
  Sample

In [15]:
# Print the first batch of the test loader
inspect_dataloader(testloader)

  self.pid = os.fork()
  self.pid = os.fork()


Batch Size: 4
Data Shape: torch.Size([4, 3, 32, 32])
Labels:
  Sample 1: 49
  Sample 2: 33
  Sample 3: 72
  Sample 4: 51
