In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision
from transformer.Vision_transformer import VisionTransformer, CustomDataset
import torchvision.transforms as transforms
from pprint import pprint
from torchsummary import summary
import json
import numpy as np

In [None]:
# Fix the global RNG so weight init, dropout masks and DataLoader shuffling are
# repeatable across Restart & Run All. torch.manual_seed also seeds every CUDA
# device. Bit-exact GPU reproducibility may additionally need deterministic
# algorithms (torch.use_deterministic_algorithms) — not enabled here.
SEED = 42
torch.manual_seed(SEED)

# Compute device: CUDA when available, otherwise CPU. Later cells move the
# model and each batch onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# class CustomDataset(Dataset):
#     """Puts incoming MNIST dataset into an object 
#         which can be loaded onto cuda gpu.
#     Parameters
#     ----------
#     data : torchvision.datasets.mnist.MNIST

#     Attributes
#     ----------
#     X : torch.Tensor
#         Shape `(n_samples, n_channels, img_height, img_width)`
#     """
#     def __init__(self, data, device = device):
#         self.X = torch.cat([torch.unsqueeze(data[i][0], dim=0) for i in range(len(data))], dim=0).to(device)
#         self.Y = torch.tensor([data[i][1] for i in range(len(data))]).to(device)
    
#     def __len__(self):
#         """Length method.
#         Parameters
#         ----------
#         None
#         Returns
#         ----------
#         int
#             n_samples

#         """
#         return self.X.shape[0]
    
#     def __getitem__(self, idx):
#         """Indexing call.
#         Parameters:
#         idx : int
#             index of element to be returned.
        
#         Returns : 
#         torch.Tensor
#             Shape `(n_channels, img_height, img_width)`
#         torch.Tensor
#             Shape `(class_idx)`
#         """
#         return self.X[idx], self.Y[idx]


In [None]:
# transform = transforms.Compose([
#     transforms.ToTensor(),
# ])   # Transform object to apply on the dataset.

# train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# # Loading/Downloading dataset. `download` can be `False` if the data is already present in the root directory;
# # otherwise it will download the dataset to the root location.

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # Getting the device to compute on. `cuda` if GPU is available, else `cpu`.

In [None]:
# Shared preprocessing: PIL image -> float tensor (ToTensor), then per-channel
# (x - 0.5) / 0.5 via Normalize on each of the three RGB channels.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training and test splits, stored under ./data (downloaded if absent).
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, transform=transform, download=True
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, transform=transform, download=True
)

# Human-readable class names; len(classes) sets the model's output size below.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


In [None]:
train_ds = CustomDataset(data=train_dataset, device=device)
test_ds = CustomDataset(data=test_dataset, device=device)
# Wrap the CIFAR-10 train/test splits in CustomDataset objects, passing the
# compute device. NOTE(review): presumably CustomDataset materialises every
# sample onto `device` up front — confirm against transformer.Vision_transformer.

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=64, shuffle=False)
# Mini-batches of 64; only the training split is shuffled, so evaluation
# order is fixed.


In [None]:

# Small Vision Transformer sized for 32x32 RGB inputs cut into 4x4 patches.
CIFAR_ViT = VisionTransformer(
    img_size=32, patch_size=4, in_chans=3,             # input geometry
    n_classes=len(classes),                            # one output per class name
    embed_dim=16, depth=4, n_heads=4, mlp_ratio=1.0,   # transformer width/depth
    p=0.3, attn_p=0.3,                                 # presumably dropout probs — confirm
)
CIFAR_ViT = CIFAR_ViT.to(device)

In [None]:
# Print a layer-by-layer summary for a single (channels, height, width) input.
summary(CIFAR_ViT, input_size=(3, 32, 32))

In [None]:

# Multiclass objective: cross-entropy between model logits and integer labels.
criterion = nn.CrossEntropyLoss()

# Adam over all ViT parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=CIFAR_ViT.parameters(), lr=1e-3)

In [None]:
num_epochs = 150
# Number of full passes over the training set.
for epoch in range(num_epochs):
    CIFAR_ViT.train()
    # Training mode (dropout active, presumably configured via `p` / `attn_p`).
    running_loss = 0.0
    n_seen = 0
    # Reporting-only accumulators (sum of per-sample losses, samples seen);
    # they play no role in the weight updates.
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        # Move the batch to `device` (CUDA if available, otherwise CPU). Likely a
        # no-op if CustomDataset already stores its tensors there — TODO confirm.

        optimizer.zero_grad()
        # Clear gradients left over from the previous step.
        outputs = CIFAR_ViT(images)
        # Forward pass: class logits for the batch.
        loss = criterion(outputs, labels)
        # Mean cross-entropy over this batch (CrossEntropyLoss default reduction).
        loss.backward()
        # Backpropagation.
        optimizer.step()
        # Apply the optimizer's update rule to the weights.

        n_batch = labels.size(0)
        running_loss += loss.item() * n_batch
        n_seen += n_batch
        # `loss` is a batch mean, so weight it by batch size; otherwise the final,
        # smaller batch would count as much as a full one in the epoch average.
    epoch_loss = running_loss / max(n_seen, 1)  # guard against an empty loader
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
    

In [None]:
def test(model, loader = test_loader):
    correct, total = 0, 0
    model.eval()
    # Setting the model in evaluation mode.
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            # Loading batch images and ground truth onto device
            outputs = model(images)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
    return f"Accuracy on test set: {(100 * correct / total):.2f}%"

In [None]:
# Evaluate the trained model on the held-out split; the returned string is the cell output.
test(model=CIFAR_ViT, loader=test_loader)

In [None]:
# Persist only the learned weights, under the 'model_state_dict' key.
# NOTE(review): "86pct" in the filename is hard-coded; it is not derived from
# the accuracy reported by the evaluation cell — keep it in sync by hand.
torch.save(
    obj={'model_state_dict': CIFAR_ViT.state_dict()},
    f="cifar_model_86pct.pth",
)