# PyTorch

Many of you might already be familiar with the basics of AI, but a reminder always helps, especially if you have not used the PyTorch framework before!

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

Let's create the simplest possible classifier for MNIST!

In [2]:
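class SimpleClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # Fully connected network: 784 input pixels -> 128 hidden units -> 10 class scores
        self.model = nn.Sequential(
            nn.Linear(28*28, 128),  # input layer: expects images flattened to 784 values
            nn.ReLU(),              # non-linearity
            nn.Linear(128, 10)      # output layer: one raw score (logit) per digit 0-9
        )

    def forward(self, x):
        # x must have shape (batch_size, 784); returns logits of shape (batch_size, 10)
        return self.model(x)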

This very simple model can be represented as follows (I made the image myself; it may not be perfect, but it captures the idea!):

<p align="center">
  <img src="./fc_mnist_simple.png" />
</p>
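As a quick sanity check on the drawing, here is a small sketch that counts the trainable parameters of `SimpleClassifier` by hand. It assumes the usual `nn.Linear` layout (a weight for every input/output pair plus one bias per output unit); `linear_params` is just an illustrative helper name. In PyTorch you could compare the result with `sum(p.numel() for p in model.parameters())`.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Trainable parameters of a fully connected layer with bias (illustrative helper)."""
    weights = in_features * out_features  # one weight per input/output pair
    biases = out_features                 # one bias per output unit
    return weights + biases

# Layer sizes of SimpleClassifier: 784 -> 128 -> 10 (ReLU has no parameters)
layers = [(28 * 28, 128), (128, 10)]
per_layer = [linear_params(i, o) for i, o in layers]
total = sum(per_layer)
print(per_layer, total)  # [100480, 1290] 101770
```

So even this "simplest" classifier has roughly 100k weights, almost all of them in the first layer.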

Now we simply declare the model, the optimizer (how the weights will be updated during training) and the loss function (how the model's output will be compared with the expected classification).
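The loss function used here, `nn.CrossEntropyLoss`, takes the raw scores (logits) produced by the last `nn.Linear` layer and, for each sample, computes the negative log of the softmax probability given to the correct class, averaged over the batch by default. That is why the model has no softmax layer at the end. Below is a minimal plain-Python sketch of that computation for intuition only (the function names are hypothetical; PyTorch does the same thing on whole tensors):

```python
import math

def cross_entropy(logits: list[float], target: int) -> float:
    """-log(softmax(logits)[target]) for one sample.
    Uses the log-sum-exp trick so large logits do not overflow exp()."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def mean_cross_entropy(batch_logits: list[list[float]], targets: list[int]) -> float:
    """Average over the batch, like the default reduction='mean'."""
    return sum(cross_entropy(l, t) for l, t in zip(batch_logits, targets)) / len(targets)

# An untrained model giving equal scores to all 10 digits
print(cross_entropy([0.0] * 10, target=3))       # about 2.3026, i.e. log(10)
# A confident model: small loss if it is right, large loss if it is wrong
print(cross_entropy([5.0, 0.0, 0.0], target=0))  # small
print(cross_entropy([5.0, 0.0, 0.0], target=1))  # just above 5
```

This gives a handy reference point: a loss around 2.3 at the very start of training means the model is still guessing uniformly among the 10 classes.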

In [3]:
model = SimpleClassifier()
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
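To give an idea of what `optim.Adam` does at each `optimizer.step()`, here is a sketch of the Adam update rule for a single scalar parameter. It uses PyTorch's documented defaults (`betas=(0.9, 0.999)`, `eps=1e-8`) but is a simplified illustration, not PyTorch's implementation: the real optimizer works element-wise on tensors and supports extra options (weight decay, AMSGrad, ...) that are left out here. `adam_step` is a hypothetical name.

```python
import math

def adam_step(param, grad, m, v, t, lr=0.001, betas=(0.9, 0.999), eps=1e-8):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    beta1, beta2 = betas
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad   # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias correction (m and v start at 0)
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

# Usage: minimize f(x) = (x - 3)^2, whose gradient is 2 * (x - 3)
x, m, v = 0.0, 0.0, 0.0
for t in range(1, 1001):
    grad = 2 * (x - 3)
    x, m, v = adam_step(x, grad, m, v, t, lr=0.05)
print(round(x, 3))  # close to 3
```

A useful property to notice: because the gradient is divided by its own running magnitude, the very first step moves each parameter by about `lr`, whatever the scale of the gradient. This is one reason `lr=0.001` is a reasonable default for Adam.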

We will now import the MNIST dataset. In a real-life application, you may have to build the dataset yourself, which is in most cases one of the hardest parts...

In [4]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Define transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])

# Load MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='./data', 
                                           train=True, 
                                           transform=transform, 
                                           download=True)

# Create DataLoader
train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=100, 
                          shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:02<00:00, 4167931.35it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 337261.87it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1063683.01it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 9896378.58it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
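The two transforms defined above act on every pixel: `ToTensor` scales the raw 0-255 grayscale values down to the range [0.0, 1.0], then `Normalize` subtracts the mean (0.1307) and divides by the standard deviation (0.3081). A small plain-Python sketch of that arithmetic for a single pixel (the helper names are illustrative, not torchvision functions):

```python
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # values used in transforms.Normalize above

def to_tensor_value(pixel: int) -> float:
    """What ToTensor does to one uint8 pixel: scale [0, 255] down to [0.0, 1.0]."""
    return pixel / 255.0

def normalize(value: float, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """What Normalize does to one value of a channel: (value - mean) / std."""
    return (value - mean) / std

# A black pixel (0) and a white pixel (255) after both transforms
for pixel in (0, 255):
    print(pixel, round(normalize(to_tensor_value(pixel)), 4))  # 0 -> -0.4242, 255 -> 2.8215
```

After normalization, the inputs are roughly centered around 0 with unit spread, which generally makes training with gradient-based optimizers better behaved.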






### And now, let's train the model!

We will do so by running the same loop 5 times (5 epochs) to see how the model adapts during training.

The DataLoader is organized into batches drawn from the raw dataset (here, 600 batches of 100 samples each).
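The 600 comes from dividing the 60,000 MNIST training images by the batch size of 100. Here is a small sketch of that arithmetic, assuming the DataLoader's default `drop_last=False` (when the division is not exact, a smaller final batch is kept). `num_batches` is a hypothetical helper, not part of PyTorch:

```python
def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields per epoch over num_samples items."""
    if drop_last:
        return num_samples // batch_size                # the incomplete last batch is discarded
    return (num_samples + batch_size - 1) // batch_size  # ceiling division: last batch may be smaller

print(num_batches(60_000, 100))  # 600: the MNIST training set used above
print(num_batches(10_000, 100))  # 100: the MNIST test set
print(num_batches(60_000, 128))  # 469: last batch holds only 60,000 - 468 * 128 = 96 images
```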

We can also check the shape (or size) of each batch and see that they are all the same!

In [21]:
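print(len(train_loader))  # number of batches per epoch
# note that print(train_loader[0]) will not work: a DataLoader is iterable but not indexable!
for i, (images, labels) in enumerate(train_loader):
    print(images.shape)  # (batch_size, channels, height, width)
    if i == 4:  # stop after the first 5 batches
        break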

600
torch.Size([100, 1, 28, 28])
torch.Size([100, 1, 28, 28])
torch.Size([100, 1, 28, 28])
torch.Size([100, 1, 28, 28])
torch.Size([100, 1, 28, 28])


### We now train our model

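For each batch, we reshape (flatten) the images, run the forward pass, compute the loss and its gradients with respect to the weights, then update the weights and move on to the next batch. Once every batch has been seen, one epoch is complete, and we start over until all epochs are done.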

In [6]:
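# Assuming you have your data loaded into train_loader
num_epochs = 5
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Flatten each 1x28x28 image into a 784-value vector
        images = images.reshape(-1, 28*28)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        optimizer.zero_grad()  # reset gradients left over from the previous batch
        loss.backward()        # compute gradients of the loss w.r.t. every weight
        optimizer.step()       # update the weights using those gradients
    
    # Note: this is the loss of the last batch of the epoch, not the epoch average
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

print("Training finished!")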

Epoch [1/5], Loss: 0.0293
Epoch [2/5], Loss: 0.0239
Epoch [3/5], Loss: 0.0365
Epoch [4/5], Loss: 0.0428
Epoch [5/5], Loss: 0.0184
Training finished!
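To make `loss.backward()` and `optimizer.step()` a bit less magical, here is a minimal plain-Python sketch of the same epoch/batch cycle on a hypothetical toy problem (learning `y = 2x` with a single weight `w`, not MNIST). The gradient is written out by hand where PyTorch would use autograd, and the update is plain gradient descent rather than Adam. As in the loop above, the printed loss is that of the last batch of each epoch; a single batch's loss is noisy, so such values can go up and down from one epoch to the next, as in the output above.

```python
import random

# Hypothetical toy data: 200 points on the line y = 2x
random.seed(0)
xs = [random.uniform(-1.0, 1.0) for _ in range(200)]
data = [(x, 2.0 * x) for x in xs]
batch_size, lr, num_epochs = 20, 0.5, 5
w = 0.0  # the single trainable weight

for epoch in range(num_epochs):
    random.shuffle(data)  # like shuffle=True in the DataLoader
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        # Forward pass and mean squared error loss
        preds = [w * x for x, _ in batch]
        loss = sum((p - y) ** 2 for p, (_, y) in zip(preds, batch)) / len(batch)
        # "Backward pass": d(loss)/dw, derived by hand
        grad = sum(2 * (p - y) * x for p, (x, y) in zip(preds, batch)) / len(batch)
        # "Optimizer step": plain gradient descent
        w -= lr * grad
    print(f"Epoch [{epoch+1}/{num_epochs}], last-batch loss: {loss:.6f}, w = {w:.4f}")
```

After a few epochs `w` ends up very close to 2. PyTorch automates exactly the two hand-written parts (the gradient and the update) for models with thousands or millions of weights.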


### We now use the exact same principle to test our model

In [22]:
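test_dataset = torchvision.datasets.MNIST(root='./data', 
                                          train=False, 
                                          transform=transform,
                                          download=True)  # no-op if the files are already in ./data
test_loader = DataLoader(dataset=test_dataset, 
                         batch_size=100, 
                         shuffle=False)  # no need to shuffle for evaluation

model.eval()  # evaluation mode (matters for layers such as dropout or batch norm)

correct = 0
total = 0

with torch.no_grad():  # no gradients needed: we only run forward passes
    for images, labels in test_loader:
        images = images.reshape(-1, 28*28)
        outputs = model(images)
        # The predicted class is the index of the highest score in each row
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Calculate accuracy
accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')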

Test Accuracy: 97.61%
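`torch.max(outputs, 1)` returns, for each row, the largest score and its index, and that index is the predicted digit. Below is a plain-Python sketch of the same accuracy computation on made-up scores for 3 classes (the `argmax` and `accuracy_percent` helpers and the toy data are hypothetical). With the 10,000 MNIST test images, the 97.61% above corresponds to 9,761 correctly classified digits.

```python
def argmax(scores: list[float]) -> int:
    """Index of the largest score; on ties this returns the first one."""
    return max(range(len(scores)), key=scores.__getitem__)

def accuracy_percent(batch_scores: list[list[float]], labels: list[int]) -> float:
    """Percentage of rows whose highest-scoring class matches the label."""
    predicted = [argmax(row) for row in batch_scores]
    correct = sum(p == y for p, y in zip(predicted, labels))
    return 100 * correct / len(labels)

scores = [[0.1, 2.0, -1.0],   # predicts class 1
          [3.0, 0.5, 0.2],    # predicts class 0
          [0.0, 0.1, 0.4]]    # predicts class 2
labels = [1, 0, 1]
print(accuracy_percent(scores, labels))  # 2 of 3 correct
```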


# Learn more...

AI is one of the most actively researched topics nowadays. Resources are everywhere, and this small, specific introduction is nothing compared to how broad the field is (even though it always comes down to the same basic logic).