In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

In [None]:
# Preprocessing applied to every image: convert to a tensor, resize to 256,
# then normalize each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(256),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

batch_size = 4

# CIFAR-10 training split (downloaded into ./data if missing), served shuffled.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2
)

# CIFAR-10 test split, served in its stored order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2
)

# Human-readable class names, indexed by the integer label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [None]:
def position_embedding_layer(num_patches, embedding_dimension, batch_size, num_channels=3):
    """Build fixed sinusoidal position embeddings for the patch tokens plus a class token.

    For position ``i`` and feature index ``j`` (with ``D = embedding_dimension``)
    the table holds ``sin(i / 10000 ** (j / D))`` when ``j`` is even and
    ``cos(i / 10000 ** ((j - 1) / D))`` when ``j`` is odd.

    Args:
        num_patches: Number of patch tokens; one extra row is added for the class token.
        embedding_dimension: Size ``D`` of each embedding vector.
        batch_size: Number of copies along the leading (batch) axis.
        num_channels: Number of copies along the second (channel) axis. Defaults to 3.

    Returns:
        Tensor of shape ``(batch_size, num_channels, num_patches + 1, embedding_dimension)``
        in torch's default floating-point dtype.
    """
    seq_len = num_patches + 1  # Add 1 for the class token

    # Compute in float64 and cast at the end so the values match a
    # double-precision evaluation of the formula.
    positions = torch.arange(seq_len, dtype=torch.float64).unsqueeze(1)  # (seq_len, 1)
    dims = torch.arange(embedding_dimension)
    is_even = dims % 2 == 0
    # Odd index j shares its frequency with the preceding even index j - 1.
    paired_dims = (dims - dims % 2).to(torch.float64)

    angles = positions / torch.pow(10000.0, paired_dims / embedding_dimension)
    table = torch.where(is_even, torch.sin(angles), torch.cos(angles))
    table = table.to(torch.get_default_dtype())

    # Prepend the batch and channel axes and copy the table along both.
    return table.repeat(batch_size, num_channels, 1, 1)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img):
    """Display a channel-first image tensor with matplotlib.

    The tensor is first mapped back through ``img / 2 + 0.5`` (undoing a
    mean-0.5 / std-0.5 normalization), then moved to channel-last layout
    for ``plt.imshow``.
    """
    unnormalized = img / 2 + 0.5
    channels_last = np.transpose(unnormalized.numpy(), (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()


# Pull a single batch from the training loader.
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Tile the batch into one grid image and display it.
image_grid = torchvision.utils.make_grid(images)
imshow(image_grid)

# Print the class name of every image in the batch, padded to 5 characters.
label_names = [f'{classes[labels[j]]:5s}' for j in range(batch_size)]
print(' '.join(label_names))

In [None]:
import torch.nn as nn
import torch.nn.functional as F

def patchify(images, num_of_patches):
    """Split a batch of square images into a grid of non-overlapping patches.

    Args:
        images: Tensor of shape ``(batch, channels, height, width)`` with
            ``height == width``.
        num_of_patches: Number of patches along each side of the grid.

    Returns:
        Tensor of shape ``(batch, num_of_patches ** 2, channels, patch_h, patch_w)``
        where ``patch_h = height // num_of_patches`` (likewise for width).
        Patches are ordered row-major over the grid. The result has the same
        dtype and device as ``images``. Rows/columns left over when the side
        is not divisible by ``num_of_patches`` are dropped.

    Raises:
        AssertionError: If the images are not square.
    """
    batch_size, channels, height, width = images.shape

    assert height == width, "Patchify method is implemented for square images only"

    height_per_patch = height // num_of_patches
    width_per_patch = width // num_of_patches

    # Drop any remainder so both spatial sides split evenly into the grid.
    cropped = images[
        :,
        :,
        : height_per_patch * num_of_patches,
        : width_per_patch * num_of_patches,
    ]

    # (B, C, rows, patch_h, cols, patch_w) -> (B, rows, cols, C, patch_h, patch_w)
    grid = cropped.reshape(
        batch_size, channels, num_of_patches, height_per_patch, num_of_patches, width_per_patch
    ).permute(0, 2, 4, 1, 3, 5)

    # Flatten the (rows, cols) grid into a single row-major patch axis.
    patches = grid.reshape(
        batch_size, num_of_patches ** 2, channels, height_per_patch, width_per_patch
    )
    # Guarantee a contiguous result so callers can safely use .view() on it.
    return patches.contiguous()

class Projection_Layer(nn.Module):
    """Flatten image patches per channel and project them to an embedding.

    Each channel of each patch is flattened to ``patch_size * patch_size``
    values, layer-normalized, linearly projected to ``embed_size // 3``
    features and layer-normalized again.

    Args:
        num_patches: Number of patches per image.
        patch_size: Side length of a (square) patch.
        in_channels: Accepted for interface compatibility; not used here.
        embed_size: Total embedding size; each channel gets ``embed_size // 3``.
    """

    def __init__(self, num_patches,patch_size,in_channels, embed_size):
        super(Projection_Layer, self).__init__()

        self.num_patches = num_patches
        self.layer_norm_1 = nn.LayerNorm(patch_size * patch_size)
        self.embed_layer = nn.Linear(patch_size * patch_size, embed_size // 3)
        self.layer_norm_2 = nn.LayerNorm(embed_size // 3)

    def forward(self, x):
        """Project patches.

        Args:
            x: Tensor of shape ``(batch, num_patches, channels, patch_h, patch_w)``.

        Returns:
            Tensor of shape ``(batch, channels, num_patches, embed_size // 3)``.
        """
        b, _, c, _, _ = x.shape
        # Swap the patch and channel axes BEFORE flattening. A bare view() only
        # reinterprets memory and would mix pixels of different patches and
        # channels into the same row.
        x = x.permute(0, 2, 1, 3, 4).reshape(b, c, self.num_patches, -1)
        x = self.layer_norm_1(x)
        x = self.embed_layer(x)
        x = self.layer_norm_2(x)

        return x

class Attention(nn.Module):
    """Multi-head self-attention over three per-channel token streams.

    Each channel is projected per head by shared Q/K/V linear layers; the three
    channel projections are concatenated along the feature axis, passed through
    a small MLP, and used for scaled dot-product attention.

    Args:
        num_heads: Number of attention heads.
        embed_size: Per-channel embedding size; each head works on
            ``embed_size // num_heads`` features.
    """

    def __init__(self,num_heads, embed_size):
        super(Attention, self).__init__()

        eff_embed_size = embed_size//num_heads
        self.num_heads = num_heads
        self.Q_matrix = nn.Linear(eff_embed_size,eff_embed_size)
        self.K_matrix = nn.Linear(eff_embed_size,eff_embed_size)
        self.V_matrix = nn.Linear(eff_embed_size,eff_embed_size)

        self.Q_final = nn.Sequential(nn.Linear(3 * eff_embed_size, 3* eff_embed_size),nn.GELU(),nn.Linear(3 * eff_embed_size, 3* eff_embed_size))
        self.K_final = nn.Sequential(nn.Linear(3 * eff_embed_size, 3* eff_embed_size),nn.GELU(),nn.Linear(3 * eff_embed_size, 3* eff_embed_size))
        self.V_final = nn.Sequential(nn.Linear(3 * eff_embed_size, 3* eff_embed_size),nn.GELU(),nn.Linear(3 * eff_embed_size, 3* eff_embed_size))

        # Divisor applied to the attention logits before the softmax.
        self.temperature = eff_embed_size**0.5

    @staticmethod
    def _merge_channels(t):
        """Concatenate the first three channel slices of ``t`` along the feature axis."""
        return torch.cat((t[:, 0], t[:, 1], t[:, 2]), dim=-1)

    def forward(self, x):
        """Apply attention.

        Args:
            x: Tensor of shape ``(batch, 3, tokens, embed_size)``.

        Returns:
            Tensor of shape ``(batch, tokens, 3 * embed_size)``; row ``n`` holds
            the concatenated head outputs for token ``n``.
        """
        bs,c,n_1,embed_dim = x.shape
        head_dim = embed_dim // self.num_heads

        # Split the feature axis into heads, then move heads in front of tokens:
        # (B, C, tokens, heads, head_dim) -> (B, C, heads, tokens, head_dim).
        # A bare view() to the target shape would mix tokens into heads.
        x = x.reshape(bs, c, n_1, self.num_heads, head_dim).permute(0, 1, 3, 2, 4)

        q = self.Q_matrix(x)
        k = self.K_matrix(x)
        v = self.V_matrix(x)

        # Concatenate queries, keys, and values across the three channels:
        # (B, heads, tokens, 3 * head_dim)
        q_final = self.Q_final(self._merge_channels(q))
        k_final = self.K_final(self._merge_channels(k))
        v_final = self.V_final(self._merge_channels(v))

        # Scale the logits BEFORE the softmax so every attention row sums to 1.
        scores = torch.matmul(q_final, k_final.transpose(-1, -2)) / self.temperature
        attention = torch.softmax(scores, dim=-1)

        out = torch.matmul(attention, v_final)  # (B, heads, tokens, 3 * head_dim)
        # Bring tokens back in front of heads before flattening the head axis.
        out = out.transpose(1, 2).reshape(bs, n_1, embed_dim * 3)

        return out

class Transformer_Block(nn.Module):
    """One encoder block: normalized attention and an MLP, each with a residual.

    Args:
        num_heads: Number of attention heads passed to ``Attention``.
        embed_size: Per-channel embedding size of the incoming tokens.
        hidden_dim: Width of the MLP's hidden layer.
        dropout: Dropout probability applied at the end of the MLP.
    """

    def __init__(self, num_heads,embed_size,hidden_dim,dropout):
        super(Transformer_Block, self).__init__()

        self.norm = nn.LayerNorm(embed_size)
        self.attn = Attention(num_heads, embed_size)
        # The MLP carries its own LayerNorm, so its residual branch is pre-normalized.
        self.MLP = nn.Sequential(
            nn.LayerNorm(embed_size),
            nn.Linear(embed_size, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_size),
            nn.Dropout(dropout)
        )

    def forward(self,x):
        """Run the block.

        Args:
            x: Tensor of shape ``(batch, channels, tokens, embed_size)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        b,c,n_1,e = x.shape

        # Normalize only the attention branch; the skip connection carries the
        # un-normalized input, matching how the MLP branch below is wired.
        attn_out = self.attn(self.norm(x))  # (b, tokens, channels * embed_size)

        # Unpack the feature axis into (channels, embed_size) per token, then
        # move channels in front of tokens. A bare view() to (b, c, n_1, e)
        # would reassign features to the wrong token positions.
        attn_out = attn_out.reshape(b, n_1, c, e).permute(0, 2, 1, 3)

        x = x + attn_out
        x = x + self.MLP(x)

        return x

class Vision_Transformer(nn.Module):
    """Vision Transformer that keeps the three image channels as separate token streams.

    Images are cut into patches, each channel of each patch is embedded to
    ``embed_size // 3`` features, a learned class token and fixed sinusoidal
    position embeddings are added, and the result passes through a stack of
    ``Transformer_Block`` layers. The class-token features of the three
    channels are concatenated and fed to a 10-way linear classifier.

    Args:
        image_size: Side length of the (square) input images.
        in_channels: Number of image channels (stored; the model is wired for 3).
        patch_size: Side length of a (square) patch.
        embed_size: Total embedding size; must be divisible by 3.
        hidden_dim: Hidden width of each block's MLP.
        num_heads: Attention heads per block.
        num_layers: Number of transformer blocks.
        dropout: Dropout probability inside each block's MLP.
        num_of_patches: Accepted for interface compatibility; not used.
    """

    def __init__(self, image_size=256, in_channels=3, patch_size=16, embed_size=192, hidden_dim=768, num_heads=8, num_layers=12, dropout=0.01, num_of_patches=4):
        super(Vision_Transformer, self).__init__()

        # Patches per image side, and total patches per image.
        self.patches_per_side = image_size // patch_size
        self.num_patches = self.patches_per_side ** 2
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.in_channels = in_channels

        self.projection_layer = Projection_Layer(self.num_patches, patch_size,in_channels, embed_size)

        self.cls_token = nn.Parameter(torch.randn(1, 1,1, embed_size//3))
        # Fixed (non-learned) position embeddings. Registered as a buffer so
        # they follow the module across .to(device)/.cuda(); persistent=False
        # keeps them out of the state_dict, as before.
        self.register_buffer(
            "pos_emb",
            position_embedding_layer(self.num_patches, embed_size//3, 1),
            persistent=False,
        )

        self.layers = nn.Sequential(*[Transformer_Block(num_heads, embed_size//3, hidden_dim, dropout)
                                      for _ in range(num_layers)])

        self.clf_head = nn.Linear(embed_size, 10)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape ``(batch, 3, image_size, image_size)``.

        Returns:
            Tensor of shape ``(batch, 10)`` with unnormalized class scores.
        """
        bs = x.shape[0]
        x = patchify(x, self.patches_per_side)  # Patchify images
        x = self.projection_layer(x)  # Flatten patches and project

        # Prepend the class token to every channel's token sequence.
        cls_token = self.cls_token.expand(bs,3, -1, -1)  # Broadcasting
        x = torch.cat([cls_token, x], dim=2)

        x = x + self.pos_emb  # Adding position embeddings

        x = self.layers(x)

        # Token 0 of every channel is the class token; concatenate the three
        # channel embeddings into one embed_size-wide vector. Viewing the
        # tensor as (bs, tokens, -1) and taking row 0 would instead pick up
        # the first few tokens of channel 0 only.
        cls_features = x[:, :, 0, :].reshape(bs, -1)
        return self.clf_head(cls_features)  # Output classification

In [None]:
# Instantiate the Vision Transformer with all of its default constructor arguments.
net = Vision_Transformer()

In [None]:
import torch.optim as optim

# SGD with momentum over every trainable parameter of the network.
optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=2e-4)
# Cross-entropy loss between the network's class scores and the integer labels.
criterion = nn.CrossEntropyLoss()

In [None]:
from tqdm import tqdm
# Make sure training-mode layers (dropout) are active even when this cell is
# re-run after an evaluation cell has switched the model to eval mode.
net.train()

for epoch in range(20):  # loop over the dataset multiple times

    running_loss = 0.0
    # Wrap the loader itself (not the enumerate object) so tqdm can read
    # len(trainloader) and show a total / percentage.
    for i, (inputs, labels) in enumerate(tqdm(trainloader), 0):
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')

In [None]:
# Switch to evaluation mode (disables dropout) for the test pass.
net.eval()

# Running totals for the test loss and accuracy
test_loss = 0.0
correct = 0
total = 0

# Disable gradient computation for test phase
with torch.no_grad():
    # Iterate the loader directly: the batch index is not needed, and wrapping
    # the loader (rather than an enumerate object) lets tqdm show a total.
    for inputs, labels in tqdm(testloader):
        # Forward pass
        outputs = net(inputs)

        # Accumulate the per-batch mean loss
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        # Predicted label = index of the highest class score
        _, predicted = torch.max(outputs, 1)

        # Update the total and correct predictions count
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Print the test loss (averaged over batches) and accuracy
print(f'Test Loss: {test_loss / len(testloader):.3f}')
print(f'Test Accuracy: {(100 * correct / total):.2f}%')

# Set the model back to training mode
net.train()
