### Original MNIST image
A 28x28 grayscale image. Each pixel is represented as a single integer between 0 (black) and 255 (white).

### transforms.ToTensor()
#### 4x4, single channel (for demonstration purposes)
before:
```python
[[128,  64,  32,  16],
 [255,  0,   128, 64],
 [32,   16,  255, 0],
 [128,  64,  32,  16]]
```
after:
```python
[[[0.5020, 0.2510, 0.1255, 0.0627],
  [1.0000, 0.0000, 0.5020, 0.2510],
  [0.1255, 0.0627, 1.0000, 0.0000],
  [0.5020, 0.2510, 0.1255, 0.0627]]]
```
#### 4x4, 3 channels
before:
```python
[
 [(128, 64, 32),  (64, 32, 16),  (32, 16, 8),   (16, 8, 4)],
 [(255, 128, 64), (0, 0, 0),     (128, 64, 32), (64, 32, 16)],
 [(32, 16, 8),   (16, 8, 4),    (255, 128, 64), (0, 0, 0)],
 [(128, 64, 32), (64, 32, 16),  (32, 16, 8),   (16, 8, 4)]
]
```

after:
```python
[
 [[0.5020, 0.2510, 0.1255, 0.0627],
  [1.0000, 0.0000, 0.5020, 0.2510],
  [0.1255, 0.0627, 1.0000, 0.0000],
  [0.5020, 0.2510, 0.1255, 0.0627]],

 [[0.2510, 0.1255, 0.0627, 0.0314],
  [0.5020, 0.0000, 0.2510, 0.1255],
  [0.0627, 0.0314, 0.5020, 0.0000],
  [0.2510, 0.1255, 0.0627, 0.0314]],

 [[0.1255, 0.0627, 0.0314, 0.0157],
  [0.2510, 0.0000, 0.1255, 0.0627],
  [0.0314, 0.0157, 0.2510, 0.0000],
  [0.1255, 0.0627, 0.0314, 0.0157]]
]
```

# Prepare data 

In [None]:
# Download MNIST training set
import torch

from torchvision import datasets, transforms

# Preprocessing pipeline: ToTensor first, then Normalize with mean 0.5 and
# std 0.5 on the single channel, i.e. (x - 0.5) / 0.5
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Fetch the MNIST training split into ./MNIST_data/ (downloaded if missing)
# and wrap it in a shuffling loader that yields mini-batches of 64
trainset = datasets.MNIST(
    root='./MNIST_data/',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=64,
    shuffle=True,
)

# Display random digit image
import matplotlib.pyplot as plt
import numpy as np
import torchvision

# Pull a single mini-batch of images and labels from the training loader
dataiter = iter(trainloader)
images, labels = next(dataiter)


def imshow(img):
    """Draw a normalized image (or image grid) with matplotlib.

    The input is rescaled with ``img / 2 + 0.5`` (the inverse of the
    mean-0.5 / std-0.5 normalization applied by ``transform``), then axis 0
    is moved to the end before the result is handed to ``plt.imshow``.
    """
    restored = img / 2 + 0.5  # unnormalize
    channels_last = np.transpose(restored, (1, 2, 0))  # convert from Tensor image
    plt.imshow(channels_last)


# Tile the whole batch into one grid image and display it
imshow(torchvision.utils.make_grid(images))

from torch import nn, optim

# Fully connected classifier: 784 inputs -> 128 -> 64 -> 10 outputs, with ReLU
# between the linear layers and a log-softmax over the class dimension.
# Layers are listed in order so nn.Sequential numbers them 0..5 as before.
layers = [
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
    nn.LogSoftmax(dim=1),
]
model = nn.Sequential(*layers)

# Prefer CUDA when torch reports it as available, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

# Train model

In [None]:
from torch import nn, optim

# Define the loss: negative log-likelihood, which pairs with the
# log-probabilities produced by the model's final LogSoftmax layer
criterion = nn.NLLLoss()

# Define the optimizer: plain SGD over every model parameter
optimizer = optim.SGD(model.parameters(), lr=0.003)

epochs = 5

# Make training mode explicit so that re-running this cell after an
# evaluation pass still trains in the right mode
model.train()

for e in range(epochs):
    # Sum of per-batch losses, reset at the start of every epoch
    running_loss = 0
    for images, labels in trainloader:
        # Move images and labels to the device
        images, labels = images.to(device), labels.to(device)

        # Flatten MNIST images into a 784 long vector
        images = images.view(images.shape[0], -1)

        # Clear gradients left over from the previous batch
        optimizer.zero_grad()

        # Forward pass and loss
        output = model(images)
        loss = criterion(output, labels)

        # This is where the model learns by backpropagating
        loss.backward()

        # And optimizes its weights here
        optimizer.step()

        running_loss += loss.item()

    # Runs once per epoch, after the last batch: mean per-batch loss
    print(f"Training loss: {running_loss/len(trainloader)}")

# Save the model (parameters only, as a state_dict)
torch.save(model.state_dict(), 'mnist_model.pth')


# Test accuracy

In [None]:
# Download MNIST test set, apply the same transform used for the training set
testset = datasets.MNIST('./MNIST_data/', download=True, train=False, transform=transform)

# Create a data loader for the test set
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

# Use the trained model, predict on the test set, and get the accuracy
correct = 0
total = 0

# Switch to evaluation mode before measuring accuracy
model.eval()

# Gradients are not needed for inference
with torch.no_grad():
    for images, labels in testloader:
        # Flatten the images into a 2D tensor (batch, 784)
        images = images.view(images.shape[0], -1)
        # Move images and labels to the device
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Index of the largest log-probability along dim 1 is the predicted class
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# %d truncates the percentage to a whole number
print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))


# Visualization

In [None]:
import matplotlib.pyplot as plt

# Number of misclassified test digits to collect and display
num_to_show = 5

# Load model. map_location lets a checkpoint saved on one device (e.g. a GPU)
# load on whichever device this session is using.
model.load_state_dict(torch.load('mnist_model.pth', map_location=device))

# Switch to evaluation mode before running inference
model.eval()

# Create lists to store images, true labels, and predicted labels
incorrect_images = []
incorrect_labels = []
incorrect_predictions = []

with torch.no_grad():
    for images, labels in testloader:
        # Flatten the images into a 2D tensor
        images = images.view(images.shape[0], -1)
        # Move images and labels to the device
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)

        # Find incorrect predictions (boolean mask over the batch)
        incorrect = (predicted != labels)
        if incorrect.any():
            incorrect_images.extend(images[incorrect].cpu().numpy())
            incorrect_labels.extend(labels[incorrect].cpu().numpy())
            incorrect_predictions.extend(predicted[incorrect].cpu().numpy())

        # Stop when we have enough images
        if len(incorrect_images) >= num_to_show:
            break

# incorrect_images, incorrect_labels, and incorrect_predictions now hold the
# misclassified samples from the batches scanned so far. Every error in a
# batch is appended, so the lists may hold more than num_to_show entries
# (or fewer, if the whole test set produced fewer errors).

# Create a 1 x num_to_show grid; squeeze=False keeps `axes` a 2D array even
# when num_to_show is 1
fig, axes = plt.subplots(1, num_to_show, figsize=(10, 2), squeeze=False)

# Remove axes on every panel up front so an unused panel stays blank
for ax in axes.flat:
    ax.axis('off')

# zip stops at the shortest input, so having fewer errors than panels cannot
# raise an IndexError
panels = zip(axes.flat, incorrect_images, incorrect_labels, incorrect_predictions)
for ax, flat_img, true_label, prediction in panels:
    # Restore the flattened 784-vector to a 28x28 image and display it
    ax.imshow(flat_img.reshape(28, 28), cmap='gray')

    # Set title as the true label and prediction
    ax.set_title(f'True: {true_label}, Pred: {prediction}')

plt.show()