# Image Classification with PyTorch

In [None]:
import torch
from torchvision import datasets, transforms

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt

### Load the data and rescale pixel values to [0, 1]

In [None]:
# Preprocessing pipeline: ToTensor turns each image into a float tensor,
# rescaling pixel values to [0, 1].
transformer = transforms.Compose(
    [transforms.ToTensor()]
)

In [None]:
# Both splits live under the same root. The train call downloads the dataset
# files if they are missing; the test split reads from the same directory.
data_root = './data'
train_dataset = datasets.FashionMNIST(root=data_root, train=True, download=True, transform=transformer)
test_dataset = datasets.FashionMNIST(root=data_root, train=False, transform=transformer)

In [None]:
# Training data: shuffled mini-batches of 128.
trainloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
)
# Test data: the entire split delivered as one batch.
testloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=len(test_dataset),
)

In [None]:
# Reverse lookup: integer label -> human-readable class name.
name_to_idx = train_dataset.class_to_idx
classes = dict(zip(name_to_idx.values(), name_to_idx.keys()))

In [None]:
# Show the first `n` test images side by side, titled with their class name.
n = 6
plt.figure(figsize=(20, 4))
for idx in range(n):
    image, label = test_dataset[idx]
    ax = plt.subplot(1, n, idx + 1)
    plt.imshow(image[0])  # first (only) channel of the image tensor
    plt.title(classes[label])
    plt.gray()
    # Hide tick marks and labels on both axes.
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_visible(False)

plt.show()

### Build model

In [None]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class Model(nn.Module):
    """Small fully connected classifier.

    Flattens each input sample, applies a 128-unit hidden layer with ReLU and
    dropout (p=0.2), then projects to 10 outputs. ``forward`` returns
    log-probabilities (log-softmax over dim 1).

    Inputs must flatten to 28 * 28 = 784 features per sample,
    e.g. shape (N, 1, 28, 28).
    """

    def __init__(self):
        super().__init__()
        hidden_units = 128
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, hidden_units)
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(hidden_units, 10)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.fc1(self.flatten(x))))
        scores = self.fc2(hidden)
        return F.log_softmax(scores, dim=1)

In [None]:
# Instantiate the classifier with freshly initialised weights.
model: nn.Module = Model()

### Train the network

In [None]:
import torch.optim as optim

# Model.forward already returns log-probabilities (log_softmax), so the matching
# loss is NLLLoss. CrossEntropyLoss would apply log_softmax a second time: the
# value is unchanged (log_softmax is idempotent), but the work is redundant and
# the pairing is misleading.
criterion = nn.NLLLoss()
# Adam with its default hyper-parameters (lr=1e-3).
optimizer = optim.Adam(model.parameters())

In [None]:
from torchmetrics.classification import MulticlassAccuracy

# NOTE: MulticlassAccuracy macro-averages over classes by default.
accuracy = MulticlassAccuracy(num_classes=len(train_dataset.classes))

num_epochs = 20
log_every = 200  # report the mean training loss every this many mini-batches

for epoch in range(num_epochs):  # loop over the dataset multiple times
    # Re-enable dropout; the evaluation step of the previous epoch disabled it.
    model.train()

    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics: mean loss over the last `log_every` mini-batches
        running_loss += loss.item()
        if i % log_every == log_every - 1:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}')
            running_loss = 0.0

    # Evaluate with dropout disabled and without building an autograd graph.
    # Accumulate over every test batch so the result is correct for any batch size.
    model.eval()
    accuracy.reset()
    with torch.no_grad():
        for inputs, labels in testloader:
            accuracy.update(model(inputs), labels)
    print(f"Accuracy on validation set: {float(accuracy.compute())}")