[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/aldomunaretto/immune_deep_learning/blob/main/notebooks/01_intro_DL/01_keras_intro.ipynb)

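# First Neural Network

Train a small fully connected neural network with Keras to classify the handwritten digits (0–9) of the MNIST dataset, then evaluate it on the test set and plot its training history.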

In [None]:
# Import required libraries
from keras import Input
from keras.models import Sequential
from keras.layers import Dense, Flatten, Dropout, Rescaling
from keras.utils import set_random_seed
from keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt

# Load the MNIST dataset (Keras downloads it on first use)
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize the input data: scale uint8 pixels from [0, 255] to [0, 1].
# Cast to float32 first: dividing a uint8 array by 255.0 would otherwise
# produce float64 arrays (twice the memory) that Keras converts back anyway.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Sanity checks the model's input layer relies on: 28x28 images, values in [0, 1]
assert x_train.shape[1:] == (28, 28) and x_test.shape[1:] == (28, 28)
assert 0.0 <= x_train.min() and x_train.max() <= 1.0

In [None]:
# Image example: display one training digit together with its label
SAMPLE_INDEX = 9  # any index into x_train works

fig, ax = plt.subplots()
# MNIST digits are single-channel, so use a grayscale colormap instead of
# matplotlib's default (viridis), which would render them in false colour.
ax.imshow(x_train[SAMPLE_INDEX], cmap='gray')
ax.set_title(f"Label: {y_train[SAMPLE_INDEX]}")
ax.axis('off')
plt.show()

In [None]:
# Training configuration
SEED = 42
EPOCHS = 10
VALIDATION_SPLIT = 0.1  # fraction of x_train held out for validation

# Seed Python, NumPy and the backend so weight initialization, dropout masks and
# shuffling are repeatable between runs (some GPU ops may still be nondeterministic).
set_random_seed(SEED)

# Define the model architecture.
# NOTE: pixels were already scaled to [0, 1] in the data-loading cell, so the model
# must not rescale again -- an extra Rescaling(1./255) layer would squash the
# inputs into [0, 1/255].
model = Sequential([
    Input(shape=(28, 28)),  # Each MNIST image is a 28x28 array (no channel axis)
    Flatten(),  # Flatten the 28x28 input image into a 1D array of 784 values
    Dense(128, activation='sigmoid'),  # Dense layer with 128 neurons and sigmoid activation
    Dropout(0.2),  # Dropout layer for regularization to prevent overfitting
    Dense(10, activation='softmax')  # Output layer with 10 neurons (one for each class) and softmax activation
])

# Compile the model
model.compile(optimizer='adam',  # Use the Adam optimizer
              loss='sparse_categorical_crossentropy',  # Labels are integer class ids, hence the "sparse" variant
              metrics=['accuracy'])  # Track accuracy metric during training

# Train the model; Keras takes the validation data from the last samples of x_train/y_train
history = model.fit(x_train, y_train, epochs=EPOCHS, validation_split=VALIDATION_SPLIT)


In [None]:
# Score the trained network on the held-out test split, which was not used in training.
# With metrics=['accuracy'], evaluate() returns [loss, accuracy] in that order.
test_loss, test_accuracy = model.evaluate(x=x_test, y=y_test)

# Report both metrics on a single line
print(f"Test loss: {test_loss}, Test accuracy: {test_accuracy}")

In [None]:
# Visualize the training history: one point per epoch.
# history.history stores one value per epoch in a list, so number epochs from 1
# (plotting the raw lists would label the first epoch as 0).
train_accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']
epoch_numbers = range(1, len(train_accuracy) + 1)

fig, ax = plt.subplots()
ax.plot(epoch_numbers, train_accuracy, label='Training accuracy')
ax.plot(epoch_numbers, val_accuracy, label='Validation accuracy')
ax.set_title('Training and validation accuracy')
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy')
ax.legend()
plt.show()