# 4.1 ***Your model***
___
Well done, you've made it this far! You now understand the key concepts of neural networks and how they are trained, but you haven't built one yet...
Don't worry: this final task will guide you through building a neural network trained to recognize handwritten digits in 28 by 28 pixel images!

In [None]:
#import all the necessary libs :)

import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time

First of all, we need to set up the dataset. We convert the images to tensors and normalize them, which makes training simpler. *(You don't need to understand this for now, just load the dataset, but don't hesitate to ask!)*
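
If you are curious, here is a minimal sketch of what the normalization does to pixel values. It assumes pixels are already in the range [0, 1] (which is what `ToTensor()` produces) and uses three made-up pixel values; `Normalize((0.5,), (0.5,))` computes `(x - mean) / std` with `mean = std = 0.5`.

```python
import torch

# Hypothetical pixel values after ToTensor(): black, mid-gray, white in [0, 1].
pixels = torch.tensor([0.0, 0.5, 1.0])

mean, std = 0.5, 0.5  # same values as Normalize((0.5,), (0.5,))
normalized = (pixels - mean) / std   # what Normalize computes for each pixel
restored = normalized * std + mean   # inverse operation (denormalization)

print(normalized.tolist())  # [-1.0, 0.0, 1.0]
print(restored.tolist())    # [0.0, 0.5, 1.0]
```

The values end up in [-1, 1], centered on zero. The inverse operation is what you need before displaying a normalized image.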

In [None]:
# Just run this code 

# transform and normalize the data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# load dataset
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# print info
print(f"Len train dataset : {len(train_dataset)}")
print(f"Len test  dataset : {len(test_dataset)}")

To understand what this code loaded, you can visualize some of the examples below!

***Don't hesitate to change the `NUMBER_OF_ELEMENTS` constant to display one or several examples.***

In [None]:
# Visualize some elements of the dataset

#TODO: Change the number of elements to display
NUMBER_OF_ELEMENTS = ...

def imshow(img):
    img = img * 0.5 + 0.5  # Denormalisation
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), cmap='gray')
    plt.axis('off')
    plt.show()

train_loader_vis = torch.utils.data.DataLoader(train_dataset, batch_size=NUMBER_OF_ELEMENTS, shuffle=True)

# Random image 
dataiter = iter(train_loader_vis)
images, labels = next(dataiter)

imshow(torchvision.utils.make_grid(images))
print('Labels :', ' '.join(f'{labels[j].item()};' for j in range(NUMBER_OF_ELEMENTS)))

___
### Create your model!

Now for the big part... Create your own model **from scratch**!

***Don't hesitate to keep it simple the first time: start with a simple fully connected network that begins with a flatten layer, then try more complex architectures afterwards!***
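
To get a feel for the shapes involved, here is a small sketch that pushes a fake batch through a flatten layer and two linear layers. It is only an illustration: the batch is random noise, and the hidden size of 32 is an arbitrary choice for this example, not a recommended value.

```python
import torch
import torch.nn as nn

# Fake batch of 4 grayscale images shaped like MNIST samples:
# (batch, channel, height, width).
dummy_batch = torch.randn(4, 1, 28, 28)

flatten = nn.Flatten()           # keeps the batch dimension, merges the rest
hidden = nn.Linear(28 * 28, 32)  # 32 is an arbitrary hidden size for this sketch
output = nn.Linear(32, 10)       # 10 outputs, one per digit class

x = flatten(dummy_batch)
print(x.shape)                   # torch.Size([4, 784])
x = torch.relu(hidden(x))
print(x.shape)                   # torch.Size([4, 32])
logits = output(x)
print(logits.shape)              # torch.Size([4, 10])
```

Each layer's input size must match the previous layer's output size, and the last layer must produce 10 values per image.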

To understand how CNNs and fully connected networks (linear models) work, check out the bonus part [here](<../Part I - The Forward Pass/1.1 concept_of_neural_network.ipynb>).

Here is a skeleton to help you create this model:


In [None]:
#TODO: Setup the Learning rate
LEARNING_RATE = ...

class MNISTModel(nn.Module):
    def __init__(self):
        super(MNISTModel, self).__init__()

        self.flatten = nn.Flatten()
        self.fc1 = ... # 28*28 -> what you want
        self.fc2 = ...
        ...
        self.fc = ... # what you want -> 10

        self.loss = ...
        self.optimizer = ... # use self parameters and the learning rate
        self.relu = ...

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = ...
        ...
        return x


    def train_model(self, epochs, train_loader):
        self.train()  # Training mode

        for epoch in range(epochs):
            start_time = time.time()  # Start time of the epoch
            running_loss = 0.0
            total_batches = 0

            for i, data in enumerate(train_loader): # Iterate over every batch of the dataset
                inputs, labels = data
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                #TODO: Implement the training loop
                
                # Gradient to zero
                # Forward pass
                # Loss calculation
                # Backward pass
                # Optimisation step


                running_loss += loss.item()
                total_batches += 1 # just help for print 
                # print every 8 mini-batches
                if (i + 1) % 8 == 0 or (i + 1) == len(train_loader):
                    print(f"\rEpoch {epoch + 1}/{epochs} | Batch {i + 1}/{len(train_loader)} | Loss : {loss.item():.4f}", end='')

            
            avg_loss = running_loss / len(train_loader)
            epoch_time = time.time() - start_time


            print(f"\nEpoch {epoch + 1}/{epochs} finished | Average Loss : {avg_loss:.4f} | Time : {epoch_time:.2f} seconds")

        # change the model_path if you want
        model_path = "mnist_model.pth"
        print('Training finished, saving model to :', model_path)
        torch.save(self.state_dict(), model_path)


    def test_model(self, test_loader):
        self.eval()  # Evaluation mode
        correct = 0
        total = 0
        with torch.no_grad():
            for data in test_loader:
                images, labels = data
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self(images)
                _, predicted = torch.max(outputs, 1)  # index of the highest score = predicted digit
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f'Accuracy of the model on {total} images is : {100 * correct / total:.2f}%')

    def load_weights(self, model_path):
        # map_location lets weights saved on a GPU load on a CPU-only machine
        self.load_state_dict(torch.load(model_path, map_location=self.device))

### Initialize the Model

Create the model in its own cell, outside the training call, so that it is instantiated only once. If you want to restart training from random weights, re-run this cell. Otherwise, each new call to `train_model` continues from the **current weights**, that is, the state reached at the end of the previous training run.
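
The sketch below illustrates this behaviour on a toy layer rather than on your model (an assumption made to keep it short: `nn.Linear` stands in for `MNISTModel`). A new instance gets fresh random weights, while loading a saved `state_dict` restores the exact weights that were saved. An in-memory buffer is used here instead of a `.pth` file.

```python
import io
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins for the real model.
first = nn.Linear(4, 2)
second = nn.Linear(4, 2)  # a new instance gets its own random initial weights

same_before = torch.equal(first.weight, second.weight)

# Save the first layer's weights to an in-memory buffer, then load them into the second.
buffer = io.BytesIO()
torch.save(first.state_dict(), buffer)
buffer.seek(0)
second.load_state_dict(torch.load(buffer))

same_after = torch.equal(first.weight, second.weight)

print(f"identical before loading: {same_before}")
print(f"identical after loading:  {same_after}")
```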

In [None]:
model = MNISTModel()

### Ready to Train Your Model!

If you’ve reached this point, you’re now ready to train your model. But before you begin, don’t forget to set up your **BATCH_SIZE** and **EPOCHS**. Let’s clarify what they mean:

- *Epochs* represent the number of times your entire dataset will be processed by the model during training. For instance, with a dataset of 60,000 images, setting ***EPOCHS=10*** means that these 60,000 images will be fed into the model 10 times in total, allowing the model to adjust its weights with each pass.


- *Batch size* is the number of images processed in each forward pass before calculating the loss and applying backpropagation to update the model's weights. For example, with ***BATCH_SIZE=64***, 64 images are fed into the model simultaneously. After calculating the loss for this batch, backpropagation adjusts the weights based on that batch’s results.
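
As an illustration of what happens for a single batch, here is a minimal sketch of one weight update. Everything in it is a toy assumption rather than the notebook's model: a single linear layer, cross-entropy loss, plain SGD with a learning rate of 0.01, and one random "batch" of 64 flattened images with random labels.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy stand-ins (assumptions for this sketch only).
toy_model = nn.Linear(28 * 28, 10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(toy_model.parameters(), lr=0.01)

inputs = torch.randn(64, 28 * 28)      # one fake batch of 64 flattened images
labels = torch.randint(0, 10, (64,))   # one fake label per image

loss_before = criterion(toy_model(inputs), labels).item()

optimizer.zero_grad()                  # reset gradients left over from the previous batch
outputs = toy_model(inputs)            # forward pass on the whole batch at once
loss = criterion(outputs, labels)      # one loss value for the batch
loss.backward()                        # backpropagation: compute the gradients
optimizer.step()                       # adjust the weights using those gradients

loss_after = criterion(toy_model(inputs), labels).item()
print(f"loss before: {loss_before:.4f} | loss after: {loss_after:.4f}")
```

On this same batch, the loss should be slightly lower after the update. Over an epoch, this update is repeated once per batch.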

For each epoch, your model will process several *batches*, each containing a specified number of images (`batch_size`). When the dataset size is not a multiple of `batch_size`, the last batch is smaller. The total number of batches per epoch can be calculated with the following formula:

$$
\text{num\_batches} = \left\lceil \frac{\text{len\_dataset}}{\text{batch\_size}} \right\rceil
$$
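
As a quick worked example, using the 60,000 training images and the batch size of 64 mentioned above: the division is not exact, so the count is rounded up and the last batch is smaller. This matches the default behaviour of `DataLoader`, which keeps the last incomplete batch (`drop_last=False`).

```python
import math

len_dataset = 60_000   # size of the MNIST training set
batch_size = 64        # example value, pick your own

num_batches = math.ceil(len_dataset / batch_size)
last_batch_size = len_dataset - (num_batches - 1) * batch_size

print(num_batches)      # 938
print(last_batch_size)  # 32
```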


In [None]:
# Create the data loaders with the chosen batch size
BATCH_SIZE = ...

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


EPOCHS = ...  # Number of epochs

# Train the model
model.train_model(EPOCHS, train_loader)

Now it's time to compete with the other students! Compare your accuracies: the model with the highest one is the ***best model*** :)

*(try to get as close as possible to 99%)*
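
For reference, here is a small sketch of how an accuracy figure is computed from model outputs. The scores and labels are made up for the example (4 images, 3 classes instead of 10): the predicted class is the index of the highest score, and accuracy is the share of predictions that match the labels.

```python
import torch

# Hypothetical raw outputs (logits) for 4 images and 3 classes, plus their true labels.
logits = torch.tensor([
    [0.1, 2.0, -1.0],
    [1.5, 0.2, 0.3],
    [0.0, 0.1, 3.0],
    [2.0, 1.0, 0.0],
])
labels = torch.tensor([1, 0, 2, 1])

predicted = logits.argmax(dim=1)               # index of the largest score in each row
correct = (predicted == labels).sum().item()   # number of matching predictions
accuracy = 100 * correct / labels.size(0)

print(predicted.tolist())  # [1, 0, 2, 0]
print(f"{accuracy:.2f}%")  # 75.00%
```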

In [None]:
model.test_model(test_loader)

In [None]:
# work in progress

# import importlib
# import visualizer
# import threading
# # Reload the module if necessary
# importlib.reload(visualizer)

# from visualizer import visualize_model

# flask_thread = threading.Thread(target=visualize_model, args=(model, 5001))
# flask_thread.start()

# print("Server is running on http://127.0.0.1:5001")