# PyTorch Computer Vision

This notebook closely follows the material available at learnpytorch.io [[1]](https://www.learnpytorch.io/), with occasional refactoring and extension for consistency of style and to make connections with other parts of the package. It is also more extensive on examples and spends less time revisiting lower-level concepts once they have been discussed.

### Computer Vision

Computer vision [[2]](https://en.wikipedia.org/wiki/Computer_vision) is concerned with the extraction of information from images. It encompasses multiple sub-domains; a few examples are given below.

* Object detection. Deals with detection of instances of objects of a certain class, like humans, buildings, or cars.
* Motion estimation. The process of determining motion vectors that describe the transformation from one 2D image to another.
* Image segmentation. Concerned with partitioning of an image into multiple regions based on some homogeneity criteria.

In essence, anything that can be described in a visual sense can be a potential computer vision problem.

In [None]:
import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

### Sample data - Fashion-MNIST

The MNIST database [[3]](https://en.wikipedia.org/wiki/MNIST_database) consists of handwritten digits used for training various image processing systems and is consequently widely used in machine learning. Fashion-MNIST [[4]](https://github.com/zalandoresearch/fashion-mnist) is a fresh take on this database and comes with an identical structure, except that instead of digits it contains clothing images falling into 10 different classes. The Fashion-MNIST dataset will serve as our sample data.

In [None]:
from torchvision import datasets
from torchvision.transforms import ToTensor


train_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=None,
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,  # get test data
    download=True,
    transform=ToTensor()
)

The resulting datasets are `FashionMNIST` objects that behave like sequences of (tensor, label) pairs. The images are tensors by courtesy of the chosen `ToTensor` transformation, and each label is an index into the `classes` attribute of the dataset. Let us break up the data exploration into visible and easy-to-digest chunks; we start with the list of clothing types.

In [None]:
print(train_data.classes)

In [None]:
image, label = train_data[0]
print(image.shape)

The shape of image tensors is 1x28x28, which translates to the color channel, the height, and the width, respectively. Height and width are measured in pixels; the single color channel means that the image is greyscale.

**Note.** This order is known as CHW. The abbreviation is commonly extended with an N prefix to form NCHW, with N standing for the number of images in a batch. PyTorch uses NCHW by default but can perform better with NHWC, a channels-last memory format, for some operators and hardware. Here, it will not make a measurable difference. See [[5]](https://pytorch.org/blog/tensor-memory-format-matters/#pytorch-best-practice) for more details; for the various color models, we refer to [[6]](https://en.wikipedia.org/wiki/Color_model).
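
As a minimal sketch of the two layouts mentioned in the note, the cell below converts a batch to the channels-last memory format. The batch is synthetic and uses three channels purely for illustration (an assumption; our data has a single channel, where the two layouts are hard to tell apart).

```python
import torch

# A hypothetical batch of 8 three-channel 28x28 images in the default NCHW layout.
batch = torch.rand(8, 3, 28, 28)
print(batch.shape, batch.stride())

# Request channels-last storage: the logical shape stays NCHW,
# only the strides (the order of the values in memory) change.
batch_cl = batch.to(memory_format=torch.channels_last)
print(batch_cl.shape, batch_cl.stride())
print(batch_cl.is_contiguous(memory_format=torch.channels_last))
```

The values and the reported shape are unchanged; the channel stride of 1 indicates that the channel values of a single pixel now sit next to each other in memory.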

In [None]:
print(train_data.classes[label])
print(image)

We may also visualize the data for several randomly picked images.

In [None]:
torch.manual_seed(42)
figure = plt.figure(figsize=(12, 8))
nrows, ncols = 4, 4
for i in range(1, nrows * ncols + 1):
    index = torch.randint(0, len(train_data), size=[1]).item()
    image, label = train_data[index]
    figure.add_subplot(nrows, ncols, i)
    plt.imshow(image.squeeze(), cmap="gray")
    plt.title(train_data.classes[label])
    plt.axis("off")
plt.show()

### Detour - batch processing

PyTorch is equipped with a powerful utility called `DataLoader` that turns a dataset into a Python iterable of batches. In an ideal world, the forward and backward passes could be done over the whole dataset at once. In practice, we are limited by memory and compute, and it is more performant to work on batches of the data. An important side effect is that the gradient descent step is also performed more often per epoch: once for each batch. Batch size is yet another hyperparameter of the model.

In [None]:
from collections.abc import Iterable
from torch.utils.data import DataLoader

BATCH_SIZE = 32

train_dataloader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_dataloader = DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
print(isinstance(train_dataloader, Iterable))
print(isinstance(test_dataloader, Iterable))
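
To make the batching behaviour concrete without touching the real data, here is a small sketch on a synthetic dataset. The names `toy_data` and `toy_loader` and the sample count of 100 are illustrative assumptions, not part of the original material.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for an image dataset: 100 greyscale 28x28 "images" with labels 0..9.
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, size=(100,))
toy_data = TensorDataset(images, labels)

toy_loader = DataLoader(toy_data, batch_size=32, shuffle=False)

# Each iteration yields one (images, labels) batch; collect the batch sizes.
batch_sizes = [batch_images.shape[0] for batch_images, _ in toy_loader]

print(len(toy_loader))  # number of batches, hence optimizer steps, per epoch
print(batch_sizes)      # the last batch holds the leftover samples
```

With 100 samples and a batch size of 32, the loader produces four batches, the last of which contains only the 4 remaining samples.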

### Linear model with flattening

Recall that the problem of identifying clothing types correctly is essentially a multiclass classification problem. Our initial approach will be using linear layers with hidden neurons. However, our data for a single image is a proper third order tensor, and linear modeling requires a first order tensor, or simply, a vector format. PyTorch offers the `Flatten` layer for preparing data to be fed into a linear layer.
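
As a quick illustration of what `Flatten` does, the sketch below applies it to a synthetic batch with the same shape as ours (32 images of 1x28x28); the random data is only a placeholder.

```python
import torch
from torch import nn

# By default, nn.Flatten keeps the first (batch) dimension and flattens the rest.
flatten = nn.Flatten()

batch = torch.rand(32, 1, 28, 28)
flat = flatten(batch)

print(batch.shape, "->", flat.shape)
```

Each image becomes a vector of 1 * 28 * 28 = 784 values, which is why the first linear layer of the model takes 784 input features.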

In [None]:
class FashionMNISTLinear(nn.Module):
    
    def __init__(
        self,
        input_shape,
        hidden_units,
        output_shape,
    ):
        super().__init__()
        
        self.layer_stack = nn.Sequential(
            nn.Flatten(),
            nn.Linear(
                in_features=input_shape,
                out_features=hidden_units
            ),
            nn.Linear(
                in_features=hidden_units,
                out_features=output_shape
            ),
        )
    
    def forward(self, x):
        return self.layer_stack(x)

In [None]:
model = FashionMNISTLinear(
    28*28,
    10,
    10,
).to(DEVICE)

As mentioned in the other notebook, the optimizer is problem-agnostic. For the loss function, we pick the generic version of cross entropy [[7]](https://en.wikipedia.org/wiki/Cross_entropy), which `nn.CrossEntropyLoss` computes directly from the raw model outputs (logits). We will also use a multiclass-specific accuracy function from `torchmetrics` [[8]](https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html#multiclassaccuracy).

In [None]:
from torchmetrics import Accuracy

loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1)
accuracy_function = Accuracy(task="multiclass", num_classes=10).to(DEVICE)
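
To see what the cross entropy loss computes, the following sketch compares `nn.CrossEntropyLoss` with a manual calculation on two hand-made examples. The logits and targets are made-up values for illustration only.

```python
import torch
from torch import nn

# Hypothetical raw model outputs (logits) for 2 samples and 3 classes.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])

# Built-in loss: takes logits and integer class labels.
builtin_loss = nn.CrossEntropyLoss()(logits, targets)

# Manual version: mean negative log-probability of the correct class.
log_probs = torch.log_softmax(logits, dim=1)
manual_loss = -log_probs[torch.arange(len(targets)), targets].mean()

print(builtin_loss.item(), manual_loss.item())
```

The two values agree up to floating point error, which also shows that no softmax layer is needed at the end of the model when this loss is used.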

In [None]:
torch.manual_seed(42)

epoch_count = 3

for epoch in range(epoch_count):
    print(f"Epoch: {epoch}\n--------")
    
    train_loss = 0
    
    for batch, (train_input, train_output) in enumerate(train_dataloader):
        train_input, train_output = train_input.to(DEVICE), train_output.to(DEVICE)
        
        model.train()
        train_classification = model(train_input)
        
        loss = loss_function(train_classification, train_output)
        # Accumulate a plain float so that the autograd history is not kept alive.
        train_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch % 400 == 0:
            print(f"Looked at {batch * len(train_input)}/{len(train_dataloader.dataset)} samples")
        
    train_loss /= len(train_dataloader)
    
    test_loss, test_accuracy = 0, 0
    
    model.eval()
    
    with torch.inference_mode():
        for test_input, test_output in test_dataloader:
            test_input, test_output = test_input.to(DEVICE), test_output.to(DEVICE)
            
            test_classification = model(test_input)
            
            test_loss += loss_function(test_classification, test_output).item()

            # torchmetrics expects the arguments in (preds, target) order.
            test_accuracy += accuracy_function(test_classification.argmax(dim=1), test_output).item()
        
        test_loss /= len(test_dataloader)
        
        test_accuracy /= len(test_dataloader)
    
    print(f"\nTrain loss: {train_loss:.5f} | Test loss: {test_loss:.5f}, Test acc: {test_accuracy * 100:.2f}%\n")
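
The evaluation step above turns the model outputs into class labels with `argmax` before measuring accuracy. A minimal sketch of that calculation in plain PyTorch, on made-up logits and targets, could look as follows.

```python
import torch

# Hypothetical logits for 4 samples and 3 classes, with their true labels.
logits = torch.tensor([[2.0, 0.1, -1.0],
                       [0.3, 0.2, 0.1],
                       [-1.0, 0.0, 4.0],
                       [0.5, 1.5, 0.0]])
targets = torch.tensor([0, 1, 2, 1])

# The predicted class is the index of the largest logit in each row.
predicted = logits.argmax(dim=1)

# Accuracy is the fraction of predictions that match the targets.
accuracy = (predicted == targets).float().mean().item()

print(predicted.tolist(), accuracy)
```

The result is a fraction between 0 and 1 rather than a percentage. Note also that averaging per-batch accuracies, as the loop does, matches the overall accuracy exactly only when all batches have the same size.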

### References

[1] Learn PyTorch for Deep Learning: Zero to Mastery book, accessed online on 2023.05.01 at https://www.learnpytorch.io/

[2] Computer vision, Wikipedia article, accessed online on 2023.05.01 at https://en.wikipedia.org/wiki/Computer_vision

[3] MNIST database, Wikipedia article, accessed online on 2023.05.01 at https://en.wikipedia.org/wiki/MNIST_database

[4] Zalando Research, Fashion MNIST database, accessed online on 2023.05.01 at https://github.com/zalandoresearch/fashion-mnist

[5] Efficient PyTorch: Tensor Memory Format Matters, blog entry, accessed online on 2023.05.01 at https://pytorch.org/blog/tensor-memory-format-matters/#pytorch-best-practice

[6] Color model, Wikipedia article, accessed online on 2023.05.01 at https://en.wikipedia.org/wiki/Color_model

[7] Cross entropy, Wikipedia article, accessed online on 2023.05.01 at https://en.wikipedia.org/wiki/Cross_entropy

[8] TorchMetrics, API documentation, accessed online on 2023.05.01 at https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html#multiclassaccuracy