In [1]:
import torch
import torch.nn as nn
import torchvision
from transformer.Vision_transformer import VisionTransformer, CustomDataset
import torchvision.transforms as transforms
from pprint import pprint
from torchsummary import summary
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the compute device: the CUDA GPU when one is usable, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Preprocessing applied to every sample as it is read from the dataset;
# `ToTensor` is the only step in the pipeline.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load the FashionMNIST train and test splits from `./data`.
# `download` can be `False` if the data is already present in the root directory;
# otherwise torchvision downloads the dataset to the root location.
# To run on MNIST instead, swap `FashionMNIST` for `MNIST` in both calls.
train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# NOTE: `device` is already selected in the preceding cell, so it is not re-created here.

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data\FashionMNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [01:06<00:00, 394569.57it/s]


Extracting ./data\FashionMNIST\raw\train-images-idx3-ubyte.gz to ./data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 101217.67it/s]


Extracting ./data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:07<00:00, 595327.06it/s]


Extracting ./data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ./data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<?, ?it/s]


Extracting ./data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\FashionMNIST\raw



In [4]:
# Custom configuration for the VisionTransformer, read from a JSON file.
# The encoding is pinned to UTF-8 so the file is decoded the same way on every
# platform instead of depending on the locale default that `open` would use.
# Refer to the documentation of the class VisionTransformer
# (`VisionTransformer.__doc__`, use pprint for cleaner display)
# for the exact details of the available customization.
with open('transformer/config.json', encoding='utf-8') as f:
    custom_config = json.load(f)


In [5]:
# Instantiate the VisionTransformer from the custom configuration,
# then move it onto the selected compute device.
MNIST_ViT = VisionTransformer(**custom_config)
MNIST_ViT = MNIST_ViT.to(device=device)

In [6]:
# Print a per-layer summary (layer type, output shape, parameter count)
# of the transformer for an input of shape (1, 28, 28).
summary(MNIST_ViT, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1             [-1, 32, 7, 7]             544
        PatchEmbed-2               [-1, 49, 32]               0
           Dropout-3               [-1, 50, 32]               0
         LayerNorm-4               [-1, 50, 32]              64
            Linear-5               [-1, 50, 96]           3,168
           Dropout-6            [-1, 2, 50, 50]               0
            Linear-7               [-1, 50, 32]           1,056
           Dropout-8               [-1, 50, 32]               0
         Attention-9               [-1, 50, 32]               0
        LayerNorm-10               [-1, 50, 32]              64
           Linear-11               [-1, 50, 12]             396
             GELU-12               [-1, 50, 12]               0
          Dropout-13               [-1, 50, 12]               0
           Linear-14               [-1,

In [7]:
# Loss criterion for the multi-class classification task.
criterion = nn.CrossEntropyLoss()
# Optimizer that updates the model weights once gradients have been computed.
optimizer = torch.optim.Adam(MNIST_ViT.parameters(), lr=1e-3)

# Wrap the torchvision FashionMNIST splits in the project's CustomDataset.
train_ds = CustomDataset(data=train_dataset)
test_ds = CustomDataset(data=test_dataset)

# DataLoaders that serve the data batch by batch (64 samples per batch).
# Only the training loader shuffles; the test loader keeps the dataset order.
train_loader = torch.utils.data.DataLoader(dataset=train_ds, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_ds, batch_size=64, shuffle=False)


In [8]:
# Number of full passes over the training set.
num_epochs = 30
for epoch in range(num_epochs):
    # Put the model in training mode for this epoch.
    MNIST_ViT.train()
    # Per-batch loss values for this epoch; used for reporting only and
    # play no role in the weight updates.
    batch_losses = []
    for images, labels in train_loader:
        # Move the batch onto the compute device.
        images = images.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        # Forward pass: predict classes for the input batch.
        outputs = MNIST_ViT(images)
        # Loss of the predictions against the ground truth.
        loss = criterion(outputs, labels)
        # Backpropagation, then a weight update by the optimizer's rule.
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    # Total loss over the epoch; the report divides it by the number of batches.
    running_loss = sum(batch_losses)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")
    

Epoch [1/30], Loss: 1.0259
Epoch [2/30], Loss: 0.6643
Epoch [3/30], Loss: 0.6050
Epoch [4/30], Loss: 0.5691
Epoch [5/30], Loss: 0.5491
Epoch [6/30], Loss: 0.5348
Epoch [7/30], Loss: 0.5187
Epoch [8/30], Loss: 0.5114
Epoch [9/30], Loss: 0.5046
Epoch [10/30], Loss: 0.4975
Epoch [11/30], Loss: 0.4912
Epoch [12/30], Loss: 0.4847
Epoch [13/30], Loss: 0.4850
Epoch [14/30], Loss: 0.4772
Epoch [15/30], Loss: 0.4721
Epoch [16/30], Loss: 0.4747
Epoch [17/30], Loss: 0.4707
Epoch [18/30], Loss: 0.4687
Epoch [19/30], Loss: 0.4652
Epoch [20/30], Loss: 0.4601
Epoch [21/30], Loss: 0.4587
Epoch [22/30], Loss: 0.4557
Epoch [23/30], Loss: 0.4549
Epoch [24/30], Loss: 0.4498
Epoch [25/30], Loss: 0.4486
Epoch [26/30], Loss: 0.4459
Epoch [27/30], Loss: 0.4436
Epoch [28/30], Loss: 0.4411
Epoch [29/30], Loss: 0.4412
Epoch [30/30], Loss: 0.4424


In [12]:
# Switch the model to evaluation mode before measuring test accuracy.
MNIST_ViT.eval()
# Running counts of correct predictions and of samples seen.
correct, total = 0, 0
# Gradients are not needed for evaluation, so tracking is disabled for the loop.
with torch.no_grad():
    for images, labels in test_loader:
        # Move the batch images and ground truth onto the compute device.
        images, labels = images.to(device), labels.to(device)
        # Forward pass producing the logits.
        outputs = MNIST_ViT(images)
        # Predicted class = index of the largest logit along dim 1.
        # `argmax` replaces `torch.max(outputs.data, 1)`: it avoids the discouraged
        # `.data` accessor and does not assign to `_`, which IPython uses for the
        # last cell output.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy on test set: {(100 * correct / total):.2f}%")

Accuracy on test set: 83.60%


In [13]:
# Save the trained model: its state dict is stored under the
# 'model_state_dict' key of the checkpoint written to "fashionMNIST.pth".
checkpoint = {'model_state_dict': MNIST_ViT.state_dict()}
torch.save(checkpoint, "fashionMNIST.pth")

: 