In [1]:
import numpy as np

In [2]:
from src.lib.model import model
from src.lib.activation_functions import ReLu, softmax
from src.lib.layers import Dense, flatten, BatchNorm1D, Conv2D, MaxPooling2D
from src.lib.loss_functions import CrossEntropy
from src.lib.trainer import trainer
from src.lib.optimizers import Adam, SGD

In [3]:

from tensorflow.keras.datasets import fashion_mnist

# Number of target categories; used below as the width of the one-hot labels
num_classes = 10

# Load the dataset: images and integer labels for the train and test splits
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# Normalise: divide pixel values by 255 (assumes 8-bit image data, giving values in [0, 1])
x_train, x_test = x_train / 255.0, x_test / 255.0

def one_hot_encode(y, num_classes):
    """Turn integer class labels into one-hot rows.

    Args:
        y: Integer label or array-like of integer labels, used to index the
            rows of an identity matrix.
        num_classes: Size of the identity matrix, i.e. the length of each
            one-hot vector.

    Returns:
        A float array whose trailing axis has length ``num_classes``; the
        entry at the label's position is 1.0 and all others are 0.0.
    """
    # Row k of the identity matrix is exactly the one-hot vector for class k.
    identity = np.identity(num_classes)
    return identity[y]

# Usage: one-hot encode the integer labels of both splits
y_train_encoded, y_test_encoded = (
    one_hot_encode(labels, num_classes) for labels in (y_train, y_test)
)

# Sanity check on the encoded shapes: prints (60000, 10) then (10000, 10)
print(y_train_encoded.shape, y_test_encoded.shape, sep="\n")


(60000, 10)
(10000, 10)


In [4]:
# Lookup from integer class index (0-9) to its human-readable category name.
# The position in the tuple is the class index.
fashion_mnist_labels = dict(enumerate((
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
)))


In [5]:
# Insert a singleton channel axis after the batch axis so each image becomes
# (1, H, W), the per-sample input shape the model reports below.
# The ndim guards keep this cell idempotent: without them, re-running the cell
# without reloading the data would add yet another axis on every execution
# (e.g. (N, 1, 1, H, W)) and break the model cell.
if x_train.ndim == 3:
    x_train = x_train[:, np.newaxis, :, :]
if x_test.ndim == 3:
    x_test = x_test[:, np.newaxis, :, :]

In [6]:
# Optimizer handed to the training loop at the bottom of this cell.
optimizer = Adam(learning_rate=0.001)

# Full array shape of the training images (batch axis included), passed to the model.
input_shape = x_train.shape

# Layer stack in forward order; the loss object is the final entry of the list.
mnist_layers = [
    # Convolutional stage: 3x3 kernels, stride 1, padding 1, 8 output channels
    Conv2D(out_channels=8, kernel_size=3, stride=1, padding=1),
    ReLu(),
    # Collapse the feature maps into one vector per sample
    flatten(),
    # Hidden fully-connected stage
    Dense(128),
    BatchNorm1D(),
    ReLu(),
    # Output layer: 10 units, matching the 10 classes
    Dense(10),
    # Loss; l2_lambda=0 presumably disables L2 regularisation — TODO confirm in src.lib
    CrossEntropy(l2_lambda=0),
]
model_mnist = model(mnist_layers, input_shape=input_shape)

# The test split is supplied explicitly through custom_train_test_set,
# in the order (x_train, y_train, x_test, y_test).
train_test_arrays = (x_train, y_train_encoded, x_test, y_test_encoded)
trainer_mnist = trainer(
    model_mnist,
    x_train,
    y_train_encoded,
    custom_train_test_set=train_test_arrays,
)

# Kept as the last expression so the notebook displays the returned tuple
# (looks like (loss, accuracy) judging by the recorded output — confirm in src.lib).
trainer_mnist.train(nb_epochs=20, optimizer=optimizer, batch_size=512)


(1, 28, 28)
(8, 28, 28)
(8, 28, 28)
(6272,)
(128,)
(128,)
(128,)
(10,)
0
2.6573632
0.1009
1
0.5095073
0.8196
2
0.36678398
0.878
3
0.32015944
0.8854
4
0.29134396
0.8957
5
0.3092122
0.8901
6
0.31009218
0.8931
7
0.29489017
0.8981
8
0.29435804
0.9057
9
0.3053244
0.9056
10
0.32968414
0.8963
11
0.343396
0.9011
12
0.33109507
0.9051
13
0.33941552
0.9052
14
0.35621095
0.9019
15
0.3576322
0.9047
16
0.37041178
0.9042
17
0.38616985
0.9037
18
0.38825038
0.9027
19
0.4000115
0.9037
20
0.41150165
0.9056


(0.41150165, 0.9056)