In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

## 1. 전처리 이전의 데이터

In [2]:
# Load only the training split, as (image, label) pairs already grouped into batches of 4.
train_ds = tfds.load(
    name='mnist',
    split='train',
    as_supervised=True,
    shuffle_files=True,
    batch_size=4,
)

In [3]:
# Inspect just the first batch of the raw (not yet preprocessed) data.
images, labels = next(iter(train_ds))
print(images.shape)
print(images.dtype)  # raw pixels arrive as integers (uint8), but the model should be fed float32
print(tf.reduce_max(images))  # values go up to 255; inputs on this scale tend to hurt training, so they get rescaled below

(4, 28, 28, 1)
<dtype: 'uint8'>
tf.Tensor(255, shape=(), dtype=uint8)


## 2. map(), lambda 알아보기

In [4]:
a = [1, 2, 3, 4, 5]


def double(in_val):
    """Return twice ``in_val``; used below as the function handed to ``map``."""
    result = 2 * in_val
    return result


# map() lazily applies `double` to every element of `a`; list() materializes the results.
doubled = list(map(double, a))
print(doubled)

[2, 4, 6, 8, 10]


In [5]:
# Same doubling as above, but with an inline anonymous (lambda) function instead of `double`.
doubled2 = list(map(lambda x: 2 * x, a))
print(doubled2)

[2, 4, 6, 8, 10]


## 3. tf.cast() 알아보기

In [6]:
t1 = tf.constant([1, 2, 3, 4, 5])  # built from Python ints, so the tensor is int32
t2 = tf.cast(t1, tf.float32)       # same values, converted to float32
print(t1.dtype)
print(t2.dtype)

<dtype: 'int32'>
<dtype: 'float32'>


## 4. Dataset.map() 활용

In [7]:
def standardization(images, labels):
    """Rescale image pixels to float32 by dividing by 255; labels pass through unchanged.

    Despite the name, this is min-max scaling by the fixed maximum 255, not
    zero-mean / unit-variance standardization. For 8-bit pixels in [0, 255]
    the result lies in [0, 1].

    Args:
        images: image tensor; assumed to hold 8-bit pixel values (uint8 for
            MNIST as loaded in this notebook).
        labels: label tensor, returned as-is.

    Returns:
        An ``(images, labels)`` tuple. This is the same element structure that
        ``as_supervised=True`` yields, so ``Dataset.map`` keeps the pairing.
    """
    # Cast before dividing: integer division on uint8 would not produce fractions.
    images = tf.cast(images, tf.float32) / 255.0
    return images, labels

In [8]:
# Peek at one batch before preprocessing: pixels are still the raw integer values.
train_ds_iter = iter(train_ds)
images, labels = next(train_ds_iter)
max_pixel = tf.reduce_max(images)
print(images.dtype, max_pixel)

<dtype: 'uint8'> tf.Tensor(255, shape=(), dtype=uint8)


In [9]:
# Bind the rescaled dataset to a new name rather than reassigning `train_ds`:
# `train_ds = train_ds.map(...)` would divide by 255 again on every re-run of this
# cell (a second run would print a max of 1/255 instead of 1.0).
train_ds_std = train_ds.map(standardization)

train_ds_iter = iter(train_ds_std)
images, labels = next(train_ds_iter)
print(images.dtype, tf.reduce_max(images))

<dtype: 'float32'> tf.Tensor(1.0, shape=(), dtype=float32)


## 5. 실제 활용 방안

In [10]:
def mnist_data_loader(batch_size=4):
    """Load MNIST train/test splits as batched datasets with pixels rescaled to float32.

    Args:
        batch_size: number of examples per batch passed to ``tfds.load``
            (defaults to 4, the value this notebook used originally).

    Returns:
        A ``(train_ds, test_ds)`` tuple. Each dataset yields ``(images, labels)``
        pairs whose images are cast to float32 and divided by 255.
    """
    def normalize(images, labels):
        # Min-max scaling of 8-bit pixels into [0, 1] (not mean/std standardization).
        # Returning a tuple keeps the (image, label) structure from as_supervised=True.
        images = tf.cast(images, tf.float32) / 255.0
        return images, labels

    train_ds, test_ds = tfds.load(name='mnist',
                                  shuffle_files=True,
                                  as_supervised=True,
                                  split=['train', 'test'],
                                  batch_size=batch_size)
    train_ds = train_ds.map(normalize)
    test_ds = test_ds.map(normalize)
    return train_ds, test_ds


train_ds, test_ds = mnist_data_loader()