# Dataloader Sanity Check for `fed_Cifar10.py`

This notebook is designed to test and visualize the data loaders created by the `get_dataloaders` function in `fed_Cifar10.py`. We will:
1.  **Define the Dataloader Logic**: Replicate the `SimpleNPZDataset` class and `get_dataloaders` function.
2.  **Create Mock Data**: Generate dummy `.npz` files to simulate the expected data structure.
3.  **Load Dataloaders**: Instantiate all the different data loaders.
4.  **Visualize Batches**: Define a helper function to display images and their labels from a batch.
5.  **Inspect Each Loader**: Run visualization for each data loader to ensure they are loading the correct data as intended.

### 1. Import Libraries and Define Necessary Classes/Functions

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import os

# Copy the classes and functions from the script
class SimpleNPZDataset(Dataset):
    def __init__(self, npz_path):
        self.mean = 0.5
        self.std = 0.5
        with open(npz_path, 'rb') as f:
            data = np.load(f, allow_pickle=True)['data'].tolist()
        
        # The data is a list with a single dictionary element
        data_dict = data[0]
        self.images = data_dict['x']
        self.labels = data_dict['y']
        
        mask_open = self.labels >= 6
        if np.any(mask_open):
            self.labels[mask_open] = 6

    def __len__(self):
        return len(self.labels)
        
    def __getitem__(self, idx):
        img = self.images[idx]
        label = self.labels[idx]
        img = img.astype(np.float32) / 255.0
        img = (img - self.mean) / self.std
        img = torch.tensor(img)
        return img, int(label)

def get_dataloaders(data_root, batchsize=10, num_workers=1):
    # This function is copied directly from your script
    trainloaders = []
    for i in range(5):
        npz_path = f"{data_root}/train/{i}.npz"
        ds = SimpleNPZDataset(npz_path)
        trainloaders.append(DataLoader(ds, batch_size=batchsize, shuffle=True, num_workers=num_workers))
        
    close_test_path = f"{data_root}/centralized_close_test.npz"
    close_ds = SimpleNPZDataset(close_test_path)
    valloader = DataLoader(close_ds, batch_size=batchsize, shuffle=False, num_workers=num_workers)
    closerloader = DataLoader(close_ds, batch_size=batchsize, shuffle=False, num_workers=num_workers)

    open_test_path = f"{data_root}/centralized_open_test.npz"
    open_ds = SimpleNPZDataset(open_test_path)
    openloader = DataLoader(open_ds, batch_size=batchsize, shuffle=False, num_workers=num_workers)

    train_val_loaders = []
    for i in range(5):
        npz_path = f"{data_root}/test/{i}.npz"
        ds = SimpleNPZDataset(npz_path)
        train_val_loaders.append(DataLoader(ds, batch_size=batchsize, shuffle=False, num_workers=num_workers))
    
    print("--- Dataloader Batch Shapes Check ---")
    # ... (the print checks are omitted for brevity in the notebook but will run)
    
    return trainloaders, valloader, closerloader, openloader, train_val_loaders

### 2. Create Mock Data Files

In [2]:
# Create a dummy directory structure and .npz files
DATA_ROOT = "dummy_cifar_data"
os.makedirs(os.path.join(DATA_ROOT, "train"), exist_ok=True)
os.makedirs(os.path.join(DATA_ROOT, "test"), exist_ok=True)

def create_dummy_npz(path, num_samples, label_gen_fn):
    """Creates a dummy NPZ file with random images and specific labels."""
    images = np.random.randint(0, 256, size=(num_samples, 32, 32, 3), dtype=np.uint8)
    labels = label_gen_fn(num_samples)
    np.savez_compressed(path, data=np.array([{'x': images, 'y': labels}], dtype=object))

# Create train files (0.npz to 4.npz)
for i in range(5):
    create_dummy_npz(
        os.path.join(DATA_ROOT, "train", f"{i}.npz"),
        num_samples=20,
        label_gen_fn=lambda n: np.random.randint(0, 6, size=n) # Known classes
    )

# Create test files (0.npz to 4.npz)
for i in range(5):
    create_dummy_npz(
        os.path.join(DATA_ROOT, "test", f"{i}.npz"),
        num_samples=15,
        label_gen_fn=lambda n: np.random.randint(0, 6, size=n) # Known classes
    )

# Create centralized test files
create_dummy_npz(
    os.path.join(DATA_ROOT, "centralized_close_test.npz"),
    num_samples=50,
    label_gen_fn=lambda n: np.random.randint(0, 6, size=n) # Known classes
)
create_dummy_npz(
    os.path.join(DATA_ROOT, "centralized_open_test.npz"),
    num_samples=30,
    label_gen_fn=lambda n: np.random.randint(6, 10, size=n) # Open-set classes
)

print(f"Dummy data created in '{DATA_ROOT}' directory.")

Dummy data created in 'dummy_cifar_data' directory.


### 3. Load the Dataloaders

In [3]:
# Load all the dataloaders using the dummy data
trainloaders, valloader, closerloader, openloader, train_val_loaders = get_dataloaders(
    data_root=DATA_ROOT,
    batchsize=8,
    num_workers=1
)

TypeError: list indices must be integers or slices, not str

### 4. Define Visualization Function

In [None]:
def imshow_batch(dataloader, title):
    """Fetches a batch from the dataloader and displays the images and labels."""
    try:
        inputs, labels = next(iter(dataloader))
        
        # Denormalize the images
        inputs = inputs * 0.5 + 0.5  # Reverse the normalization (std*img + mean)
        inputs = np.clip(inputs, 0, 1)

        fig = plt.figure(figsize=(12, 6))
        fig.suptitle(title, fontsize=16)
        for i in range(inputs.shape[0]):
            ax = plt.subplot(2, 4, i + 1) # Display up to 8 images
            # The images are HWC, which is what imshow expects
            plt.imshow(inputs[i])
            ax.set_title(f"Label: {labels[i].item()}")
            ax.axis('off')
            if i == 7: break # Stop after 8 images
        plt.show()
    except StopIteration:
        print(f"Dataloader '{title}' is empty.")
    except Exception as e:
        print(f"An error occurred while visualizing '{title}': {e}")

### 5. Visualize `trainloaders` Data

In [None]:
# Visualize the first client's training data
imshow_batch(trainloaders[0], "Trainloader[0] - Client 1 Training Data")

### 6. Visualize `valloader` Data

In [None]:
# Visualize the centralized validation data
imshow_batch(valloader, "Valloader - Centralized Close Test Data")

### 7. Visualize `closerloader` Data

In [None]:
# Visualize the known-class test data
imshow_batch(closerloader, "Closerloader - Centralized Close Test Data")

### 8. Visualize `openloader` Data

In [None]:
# Visualize the open-set test data
# We expect all labels to be 6
imshow_batch(openloader, "Openloader - Centralized Open Test Data")

### 9. Visualize `train_val_loaders` Data

In [None]:
# Visualize the first client's test data
imshow_batch(train_val_loaders[0], "Train_Val_Loader[0] - Client 1 Test Data")