In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision
from transformer.Vision_transformer import VisionTransformer
import torchvision.transforms as transforms
from pprint import pprint
from torchsummary import summary
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Compute target shared by the data and the model: the GPU when CUDA is
# usable, the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
class CustomDataset(Dataset):
    """Materialise a map-style dataset as two tensors held on one device.

    Every sample of ``data`` is read exactly once at construction time. The
    images are stacked into ``X``, the labels are collected into ``Y``, and
    both are moved to ``device``. Indexing afterwards is a plain tensor
    lookup, so no per-sample transform runs during training.

    Parameters
    ----------
    data : indexable dataset
        Object supporting ``len(data)`` and ``data[i]``, where ``data[i]``
        yields ``(image, label)``. ``image`` must be a ``torch.Tensor`` with
        the same shape for every ``i`` (e.g. a torchvision dataset whose
        transform ends in ``ToTensor``).
    device : torch.device, optional
        Where to place ``X`` and ``Y``. ``None`` (the default) selects
        ``cuda`` when it is available and ``cpu`` otherwise.

    Attributes
    ----------
    X : torch.Tensor
        Shape `(n_samples, *image_shape)`, i.e.
        `(n_samples, n_channels, img_height, img_width)` for image tensors.
    Y : torch.Tensor
        Shape `(n_samples,)`, one label per sample.
    """
    def __init__(self, data, device=None):
        # Resolve the default at call time instead of freezing a notebook
        # global into the signature when the class is defined.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Single pass: each data[i] may run an expensive transform, so fetch
        # the (image, label) pair once rather than once per field.
        images, labels = [], []
        for i in range(len(data)):
            image, label = data[i]
            images.append(image)
            labels.append(label)

        self.X = torch.stack(images, dim=0).to(device)
        self.Y = torch.tensor(labels).to(device)

    def __len__(self):
        """Length method.

        Returns
        -------
        int
            n_samples
        """
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Indexing call.

        Parameters
        ----------
        idx : int
            Index of the element to be returned.

        Returns
        -------
        torch.Tensor
            The image at ``idx``, shape `(n_channels, img_height, img_width)`
            for image tensors.
        torch.Tensor
            The label at ``idx`` (zero-dimensional tensor for an int index).
        """
        return self.X[idx], self.Y[idx]


In [4]:
# Preprocessing for MNIST: only the PIL -> tensor conversion.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST train / test splits stored under ./data. With download=True the files
# are fetched into the root directory when they are not already present;
# `download` can be False if the data is already there.
# NOTE(review): the CIFAR-10 cell that follows rebinds `transform`,
# `train_dataset` and `test_dataset`, so these MNIST objects are replaced
# before anything consumes them.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Device to compute on: `cuda` if a GPU is available, else `cpu`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Getting the device to compute on. `cuda` if GPU is available, else `cpu`.

In [5]:
# CIFAR-10 preprocessing: PIL image -> tensor, then normalise each of the
# three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training split (downloaded into ./data when missing).
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# CIFAR-10 test split.
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Human-readable names for the ten classes; its length is used as the
# number of model outputs.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Wrap whatever `train_dataset` / `test_dataset` currently refer to in
# CustomDataset objects, then expose each split through a DataLoader that
# yields batches of 64 samples.
train_ds = CustomDataset(data=train_dataset)
train_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=64,
    shuffle=True,  # reshuffle the training samples on every pass
)

test_ds = CustomDataset(data=test_dataset)
test_loader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=64,
    shuffle=False,  # keep evaluation order fixed
)
# DataLoaders for fast implementation of loading batch-wise data.


In [15]:
# Hyper-parameters of the Vision Transformer used for CIFAR-10, collected in
# one place so they can be inspected or logged.
vit_config = dict(
    img_size=32,             # input images are 32x32
    patch_size=8,            # 8x8 patches
    in_chans=3,              # RGB input
    n_classes=len(classes),  # one output per class name
    embed_dim=128,
    depth=4,
    n_heads=4,
    mlp_ratio=0.5,
    p=0.3,
    attn_p=0.3,
)

# Build the model and move its parameters to the compute device.
CIFAR_ViT = VisionTransformer(**vit_config).to(device)

In [16]:
# Print per-layer output shapes and parameter counts for a single
# 3-channel 32x32 input.
input_shape = (3, 32, 32)
summary(CIFAR_ViT, input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 128, 4, 4]          24,704
        PatchEmbed-2              [-1, 16, 128]               0
           Dropout-3              [-1, 17, 128]               0
         LayerNorm-4              [-1, 17, 128]             256
            Linear-5              [-1, 17, 384]          49,536
           Dropout-6            [-1, 4, 17, 17]               0
            Linear-7              [-1, 17, 128]          16,512
           Dropout-8              [-1, 17, 128]               0
         Attention-9              [-1, 17, 128]               0
        LayerNorm-10              [-1, 17, 128]             256
           Linear-11               [-1, 17, 64]           8,256
             GELU-12               [-1, 17, 64]               0
          Dropout-13               [-1, 17, 64]               0
           Linear-14              [-1, 

In [17]:

# Adam updates every parameter of the model with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=CIFAR_ViT.parameters(), lr=1e-3)

# Cross-entropy loss for the multiclass classification task.
criterion = nn.CrossEntropyLoss()
# Optimizer to update weights after calculating gradients.

In [18]:
num_epochs = 150  # number of full passes over train_loader

for epoch in range(num_epochs):
    # Put the model in training mode at the start of every epoch.
    CIFAR_ViT.train()

    # Sum of the batch losses in this epoch; used only for the log line
    # below and plays no part in the optimisation.
    running_loss = 0.0

    for batch in train_loader:
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)

        # One optimisation step: clear old gradients, run the forward pass,
        # compare predictions with the ground truth, back-propagate, update.
        optimizer.zero_grad()
        outputs = CIFAR_ViT(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean batch loss for the epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")
    

Epoch [1/150], Loss: 1.8505
Epoch [2/150], Loss: 1.6677
Epoch [3/150], Loss: 1.5761
Epoch [4/150], Loss: 1.5164
Epoch [5/150], Loss: 1.4654
Epoch [6/150], Loss: 1.4317
Epoch [7/150], Loss: 1.3975
Epoch [8/150], Loss: 1.3718
Epoch [9/150], Loss: 1.3386
Epoch [10/150], Loss: 1.3196
Epoch [11/150], Loss: 1.2947
Epoch [12/150], Loss: 1.2717
Epoch [13/150], Loss: 1.2457
Epoch [14/150], Loss: 1.2294
Epoch [15/150], Loss: 1.2054
Epoch [16/150], Loss: 1.1933
Epoch [17/150], Loss: 1.1777
Epoch [18/150], Loss: 1.1518
Epoch [19/150], Loss: 1.1354
Epoch [20/150], Loss: 1.1202
Epoch [21/150], Loss: 1.1075
Epoch [22/150], Loss: 1.0885
Epoch [23/150], Loss: 1.0814
Epoch [24/150], Loss: 1.0680
Epoch [25/150], Loss: 1.0547
Epoch [26/150], Loss: 1.0392
Epoch [27/150], Loss: 1.0359
Epoch [28/150], Loss: 1.0177
Epoch [29/150], Loss: 1.0112
Epoch [30/150], Loss: 1.0030
Epoch [31/150], Loss: 0.9932
Epoch [32/150], Loss: 0.9853
Epoch [33/150], Loss: 0.9831
Epoch [34/150], Loss: 0.9732
Epoch [35/150], Loss: 0

In [19]:
def test(model, loader = test_loader):
    correct, total = 0, 0
    model.eval()
    # Setting the model in evaluation mode.
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            # Loading batch images and ground truth onto device
            outputs = model(images)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
    return f"Accuracy on test set: {(100 * correct / total):.2f}%"

In [23]:
# Accuracy of the trained model on the held-out split.
test(model=CIFAR_ViT, loader=test_loader)

'Accuracy on test set: 68.14%'

In [22]:
# Persist only the learned weights, stored under the 'model_state_dict' key.
# NOTE(review): the filename says "86pct" while the accuracy output recorded
# in this notebook is 68.14% -- confirm which run this checkpoint belongs to.
checkpoint = {'model_state_dict': CIFAR_ViT.state_dict()}
torch.save(checkpoint, "cifar_model_86pct.pth")