In [2]:
import torch
import torchvision

torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.cuda.get_device_name(0)

'NVIDIA GeForce RTX 3090'

In [26]:
mnist = torchvision.datasets.MNIST('./data/', download=True)

mnist_train = [mnist[i] for i in range(50000)]
mnist_test = [mnist[i + 50000] for i in range(10000)]

In [27]:
import numpy as np
print(mnist[0])
img = np.array(mnist[0][0])
print(img.max(), img.min(), img.shape)

(<PIL.Image.Image image mode=L size=28x28 at 0x24E4CA17A00>, 5)
255 0 (28, 28)


In [28]:
y_counts = np.zeros((10,))
for _, y in mnist_train:
    y_counts[y] += 1

print(y_counts)

y_counts = np.zeros((10,))
for _, y in mnist_test:
    y_counts[y] += 1

print(y_counts)

[4932. 5678. 4968. 5101. 4859. 4506. 4951. 5175. 4842. 4988.]
[ 991. 1064.  990. 1030.  983.  915.  967. 1090. 1009.  961.]


In [29]:
y_idx_list = [[] for _ in range(10)]
for idx, data in enumerate(mnist_train):
    y_idx_list[data[1]].append(idx)

import json
with open('./data/mnist_class_idx.txt', 'w+') as jfile:
    json.dump(y_idx_list, jfile)

In [34]:
def sample_mnist(num_per_class=100):
    Xs = []
    Ys = []
    with open('./data/mnist_class_idx.txt') as jfile:
        indices_list = json.load(jfile)
    for class_idx in range(len(indices_list)):
        X = np.zeros((num_per_class, 28, 28))
        Y = np.zeros((num_per_class,))
        sample_indices = np.random.choice(indices_list[class_idx], size=num_per_class, replace=False)
        for X_idx, sample_idx in enumerate(sample_indices):
            cur_X, cur_Y = mnist_train[sample_idx]
            X[X_idx] = np.array(cur_X)
            Y[X_idx] = cur_Y
        Xs.append(X)
        Ys.append(Y)
    X_raw = np.concatenate(Xs, axis=0)
    Y_raw = np.concatenate(Ys, axis=0)
    idx_arr = np.arange(X_raw.shape[0])
    np.random.shuffle(idx_arr)
    return X_raw[idx_arr], Y_raw[idx_arr]

In [35]:
X, Y = sample_mnist(10)
print(X[0], Y)

[[  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
    0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.]
 [  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
    0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.]
 [  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
    0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.]
 [  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
    0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.]
 [  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   2.  31.
  130. 222. 255. 255. 215.  86.   0.   0.   0.   0.   0.   0.   0.   0.]
 [  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   2.  92. 156. 253.
  253. 253. 253. 253. 253. 213.  87.   0.   0.   0.   0.   0.   0.   0.]
 [  0.   0.   0.   0.   0.   0.   0.   0.   0.  24. 156. 253. 253. 253.
  253. 240. 240. 253. 253. 253. 213.  17.   0.   0.   0.  

In [36]:
X_subset, Y_subset = sample_mnist(100)
with open('./data/mnist_train_100.npz', 'wb+') as data_file:
    np.savez(data_file, X=X_subset, Y=Y_subset)

In [37]:
X_subset, Y_subset = sample_mnist(300)
with open('./data/mnist_train_300.npz', 'wb+') as data_file:
    np.savez(data_file, X=X_subset, Y=Y_subset)

In [38]:
X_subset, Y_subset = sample_mnist(500)
with open('./data/mnist_train_500.npz', 'wb+') as data_file:
    np.savez(data_file, X=X_subset, Y=Y_subset)

In [39]:
X_subset, Y_subset = sample_mnist(1000)
with open('./data/mnist_train_1000.npz', 'wb+') as data_file:
    np.savez(data_file, X=X_subset, Y=Y_subset)

In [40]:
X_subset, Y_subset = sample_mnist(2500)
with open('./data/mnist_train_2500.npz', 'wb+') as data_file:
    np.savez(data_file, X=X_subset, Y=Y_subset)

In [42]:
X = np.zeros((10000, 28, 28))
Y = np.zeros((10000))

for idx, data in enumerate(mnist_test):
    x_raw, y_raw = data
    X[idx] = np.array(x_raw)
    Y[idx] = y_raw

idx_arr = np.arange(X.shape[0])
np.random.shuffle(idx_arr)
X_shuffle = X[idx_arr]
Y_shuffle = Y[idx_arr]

with open('./data/mnist_test.npz', 'wb+') as data_file:
    np.savez(data_file, X=X_shuffle, Y=Y_shuffle)

In [43]:
X = np.zeros((50000, 28, 28))
Y = np.zeros((50000))

for idx, data in enumerate(mnist_train):
    x_raw, y_raw = data
    X[idx] = np.array(x_raw)
    Y[idx] = y_raw

idx_arr = np.arange(X.shape[0])
np.random.shuffle(idx_arr)
X_shuffle = X[idx_arr]
Y_shuffle = Y[idx_arr]

with open('./data/mnist_train_full.npz', 'wb+') as data_file:
    np.savez(data_file, X=X_shuffle, Y=Y_shuffle)

In [8]:
from data import load_mnist

X, Y = load_mnist(100)
print(X.shape, Y.shape)

torch.Size([1000, 28, 28]) torch.Size([1000])


In [4]:
X, Y = load_mnist(100, True)
print(X.shape, Y.shape)

torch.Size([1000, 784]) torch.Size([1000])


In [7]:
X, Y = load_mnist(0, False)
print(X.shape, Y.shape)

torch.Size([50000, 28, 28]) torch.Size([50000])
