In [1]:
import numpy as np
import tensorflow as tf
import logging
import os
import random
import torch
from abc import ABC, abstractmethod
import time
from sklearn.manifold import TSNE
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import keras 
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input
from keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten
from keras import backend as k
from torch.utils.data import Subset

In [2]:
# Data Stuff
# Set download to False if already downloaded
"""""
transform = transforms.Compose([transforms.ToTensor()])

# Load the MNIST dataset 
mnist_trainset = datasets.MNIST(root='./data', train=True, download=False, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=False, transform=transform)

print(f"Data: {mnist_trainset.data.shape}, Targets: {mnist_trainset.targets.shape}")
"""""
# test data
(x_train, y_train), (x_test, y_test) = mnist.load_data()


In [3]:
img_rows, img_cols=28, 28

if k.image_data_format() == 'channels_first':
   x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
   x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
   inpx = (1, img_rows, img_cols)

else:
   x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
   x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
   inpx = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

In [4]:
print(x_train.shape)
print(y_train.shape)
print(type(y_train))
K = np.unique(y_train)
print(K)
print(inpx)

(60000, 28, 28, 1)
(60000,)
<class 'numpy.ndarray'>
[0 1 2 3 4 5 6 7 8 9]
(28, 28, 1)


In [5]:
# Custom dataset class that wraps x_train and y_train
class MNISTDataset(Dataset):
    def __init__(self, data, targets):
        # data: the images (flattened or original, depending on your requirement)
        # targets: the labels (e.g., the digit for each image)
        self.data = data
        self.targets = targets

    def __len__(self):
        # Return the number of samples in the dataset
        return len(self.data)

    def __getitem__(self, idx):
        # Returns a single sample and its corresponding label
        return self.data[idx], self.targets[idx]

In [6]:
train_dataset = MNISTDataset(x_train, y_train)

# Help Function

In [7]:
def balanced_dirichlet_partition(dataset, partitions_number=10, alpha=0.5, seed=42):
    """
    Partition the dataset into multiple subsets using a Dirichlet distribution,
    ensuring that each partition contains samples from every class.

    Args:
        dataset (torch.utils.data.Dataset): The dataset to partition. It should have 'data' and 'targets' attributes.
        partitions_number (int): Number of partitions to create.
        alpha (float): The concentration parameter of the Dirichlet distribution. A lower alpha value leads to more imbalanced partitions.
        seed (int): Random seed for reproducibility.

    Returns:
        dict: A dictionary where keys are partition indices (0 to partitions_number-1)
              and values are lists of indices corresponding to the samples in each partition.
    """
    np.random.seed(seed)
    
    # Extract targets (labels) from the dataset
    y_train = dataset.targets
    
    # Number of classes in the dataset
    num_classes = len(np.unique(y_train))
    
    # Initialize the map that will store the indices for each partition
    net_dataidx_map = {}
    
    # Initialize lists to store the indices for each class
    class_indices = {k: np.where(y_train == k)[0] for k in range(num_classes)}
    
    # Shuffle the indices within each class to ensure random distribution
    for k in class_indices.keys():
        np.random.shuffle(class_indices[k])
    
    # Ensure that each partition gets at least one sample from each class
    min_size = 10  # Ensuring each class has at least 10 samples in each partition
    idx_batch = [[] for _ in range(partitions_number)]

    # Assign at least `min_size` samples from each class to every partition
    for k in range(num_classes):
        # Split the class indices equally across the partitions
        split = np.array_split(class_indices[k], partitions_number)
        for i in range(partitions_number):
            idx_batch[i].extend(split[i][:min_size])

        # Remove the samples that were assigned equally
        class_indices[k] = class_indices[k][min_size*partitions_number:]

    # Now distribute the remaining samples using the Dirichlet distribution
    for k in range(num_classes):
        remaining_indices = class_indices[k]
        proportions = np.random.dirichlet(np.repeat(alpha, partitions_number))
        proportions = (np.cumsum(proportions) * len(remaining_indices)).astype(int)[:-1]
        split_remaining = np.split(remaining_indices, proportions)
        
        for i in range(partitions_number):
            idx_batch[i].extend(split_remaining[i])

    # Shuffle the indices within each partition
    for i in range(partitions_number):
        np.random.shuffle(idx_batch[i])
        net_dataidx_map[i] = idx_batch[i]

    return net_dataidx_map


# Initial Setup

In [8]:
# Partition 

partitioned_data = balanced_dirichlet_partition(train_dataset, partitions_number=10, alpha=0.5)
partition_0_dataset = Subset(train_dataset, partitioned_data[0])

In [21]:
# Check sizes

for i in partitioned_data:
    print(len(partitioned_data[i]))

8115
6299
7690
3996
5373
5975
6862
3842
6307
5541


In [10]:
# Reformatting Subsets

x_train1 = []
y_train1 = []

for img, label in partition_0_dataset:
    x_train1.append(img)
    y_train1.append(label)

x_train1 = np.array(x_train1)
y_train1 = np.array(y_train1)

# Check shapes to confirm they match the original structure
print("x_train1 shape:", x_train1.shape)  # Should match the shape (num_samples, 784) if flattened
print("y_train1 shape:", y_train1.shape)  # Should match the shape (num_samples,)

x_train1 shape: (8115, 28, 28, 1)
y_train1 shape: (8115,)


In [11]:
L = np.unique(y_train1)
print(L)

[0 1 2 3 4 5 6 7 8 9]


In [12]:
# Convert the labels for model

y_train = keras.utils.to_categorical(y_train)

y_train1 = keras.utils.to_categorical(y_train1)


In [13]:

print(x_train1.shape)
print(y_train1.shape)


(8115, 28, 28, 1)
(8115, 10)


In [14]:
inpx = Input(shape=inpx)
layer1 = Conv2D(32, kernel_size=(3, 3), activation='relu')(inpx)
layer2 = Conv2D(64, (3, 3), activation='relu')(layer1)
layer3 = MaxPooling2D(pool_size=(3, 3))(layer2)
layer4 = Dropout(0.5)(layer3)
layer5 = Flatten()(layer4)
layer6 = Dense(250, activation='sigmoid')(layer5)
layer7 = Dense(10, activation='softmax')(layer6)


In [18]:
model = Model([inpx], layer7)
model.compile(optimizer=keras.optimizers.Adadelta(),
              loss=keras.losses.categorical_crossentropy,
              metrics=['accuracy'])

model.fit(x_train1, y_train1, epochs=12, batch_size=64)

Epoch 1/12
[1m127/127[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 25ms/step - accuracy: 0.1295 - loss: 2.7049
Epoch 2/12
[1m127/127[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 25ms/step - accuracy: 0.1311 - loss: 2.6580
Epoch 3/12
[1m127/127[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 24ms/step - accuracy: 0.1305 - loss: 2.6198
Epoch 4/12
[1m127/127[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 24ms/step - accuracy: 0.1344 - loss: 2.5679
Epoch 5/12
[1m127/127[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 24ms/step - accuracy: 0.1358 - loss: 2.5173
Epoch 6/12
[1m127/127[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 24ms/step - accuracy: 0.1319 - loss: 2.4605
Epoch 7/12
[1m127/127[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 24ms/step - accuracy: 0.1390 - loss: 2.4072
Epoch 8/12
[1m127/127[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 24ms/step - accuracy: 0.1442 - loss: 2.3512
Epoch 9/12
[1m127/127[0m [32m

<keras.src.callbacks.history.History at 0x2b2813b2fd0>

In [23]:
y_test = keras.utils.to_categorical(y_test)

score = model.evaluate(x_test, y_test, verbose=0)
print('loss=', score[0])
print('accuracy=', score[1])

loss= 2.291794538497925
accuracy= 0.09769999980926514
