In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import numpy as np

In [None]:
# Release unoccupied cached GPU memory held by PyTorch's CUDA caching allocator
# (presumably to reclaim memory left over from an earlier run in this kernel).
torch.cuda.empty_cache()

In [None]:
# Select the compute device: CUDA when a GPU is visible, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(
    "GPU is available for computation."
    if device.type == "cuda"
    else "GPU is not available."
)

In [None]:
# CIFAR-10 input pipeline.
# Images are converted to tensors and normalised per channel with mean 0.5 and
# std 0.5; both splits are served in batches of 32 by two worker processes.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Datasets (downloaded into ./data on first use).
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Loaders: only the training split is shuffled.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
)


In [None]:
# Sanity check on the training split: the sample count, then the shape of the
# `data` array stored on the loader's dataset.
num_train_samples = len(trainset)
print("Number of training samples: ", num_train_samples)

train_data = trainloader.dataset.data
print(train_data.shape)

In [None]:
class LayerNormMLP(nn.Module):
    """Two-layer MLP with a residual connection followed by layer normalisation.

    Computes ``LayerNorm(x + Linear2(ReLU(Linear1(x))))``.

    Args:
        input_dim: Size of the last dimension of the input; the output has the
            same size because ``linear2`` projects back to ``input_dim``.
        hidden_dim: Width of the hidden layer between the two linear maps.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.layer_norm = nn.LayerNorm(input_dim)

    def forward(self, x):
        """Apply the MLP, add the residual, and normalise.

        Args:
            x: Tensor whose last dimension equals ``input_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # No shape printing here: forward() runs on every batch, so debug
        # output would flood the notebook and slow training.
        mlp_output = self.linear2(F.relu(self.linear1(x)))
        return self.layer_norm(mlp_output + x)

In [None]:

class SAttention(nn.Module):
    """A single scaled dot-product self-attention head.

    Args:
        input_dim: Size of the last dimension of the input.
        heads: Output width of the query/key/value projections. Despite the
            name this is a feature size, not a head count: it is passed
            directly to ``nn.Linear`` as ``out_features``.
    """

    def __init__(self, input_dim,heads):
        super().__init__()
        self.query = nn.Linear(input_dim, heads)
        self.key = nn.Linear(input_dim, heads)
        self.value = nn.Linear(input_dim, heads)

    def forward(self, x):
        """Project ``x`` to queries, keys and values and attend over them.

        Args:
            x: Tensor whose last dimension equals ``input_dim`` and whose
                second-to-last dimension indexes the attended positions.

        Returns:
            Tensor shaped like ``x`` except that the last dimension is
            ``heads``.
        """
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        return self.scaled_dot_product_attention(q, k, v)

    def scaled_dot_product_attention(self, q, k, v):
        """Return ``softmax(q @ k^T / sqrt(d_k)) @ v``.

        ``d_k`` is the size of the last dimension of ``q``. The softmax is
        taken over the last dimension of the score matrix, so each query
        position's weights sum to one.
        """
        d_k = q.size(-1)
        # Plain Python float scale keeps numpy scalars out of the tensor math.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
        attention = F.softmax(scores, dim=-1)
        return torch.matmul(attention, v)


In [None]:

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from independent ``SAttention`` heads.

    Each head projects the input to ``input_dim // heads`` features; the head
    outputs are concatenated along the last dimension and mapped back to
    ``input_dim`` by the output projection ``W_o``.

    Args:
        input_dim: Size of the last dimension of the input and of the output.
        heads: Number of attention heads. Must be positive and divide
            ``input_dim`` evenly.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide
            ``input_dim``.
    """

    def __init__(self, input_dim, heads):
        super().__init__()
        if heads <= 0 or input_dim % heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by a positive "
                f"number of heads (got {heads})"
            )
        self.head_dim = input_dim // heads
        self.attention_heads = nn.ModuleList(
            [
                SAttention(input_dim=input_dim, heads=self.head_dim)
                for _ in range(heads)
            ]
        )
        # The concatenated heads are head_dim * heads wide. The original code
        # used the builtin `input` here, which raised TypeError on construction.
        self.W_o = nn.Linear(self.head_dim * heads, input_dim)

    def forward(self, x):
        """Run every head on ``x``, concatenate, and apply ``W_o``.

        Args:
            x: Tensor whose last dimension equals ``input_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        head_outputs = [head(x) for head in self.attention_heads]
        return self.W_o(torch.cat(head_outputs, dim=-1))

In [None]:
class MLP(nn.Module):
    """Feed-forward block: ``Linear -> ReLU -> Linear``.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layer.
        out_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, hidden_dim,out_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.activation = nn.ReLU()
        # Project to out_dim. The original ignored this argument and always
        # projected back to input_dim.
        self.linear2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Return ``linear2(ReLU(linear1(x)))``.

        Args:
            x: Tensor whose last dimension equals ``input_dim``.

        Returns:
            Tensor shaped like ``x`` except that the last dimension is
            ``out_dim``.
        """
        hidden = self.activation(self.linear1(x))
        return self.linear2(hidden)

In [None]:

class TransformerLayer(nn.Module):
    """Pre-norm transformer encoder layer.

    Applies ``x + attention(layer_norm1(x))`` followed by
    ``x + mlp(layer_norm2(x))``.

    Args:
        input_dim: Size of the last dimension of the input and output.
        num_heads: Number of attention heads passed to ``MultiHeadAttention``.
        mlp_ratio: Hidden width of the MLP as a multiple of ``input_dim``
            (truncated to an int).
    """

    def __init__(self, input_dim, num_heads,mlp_ratio):
        super().__init__()
        # Normalise over the feature dimension. The original passed the builtin
        # `input` here, which raised TypeError on construction.
        self.layer_norm1 = nn.LayerNorm(input_dim)
        self.attention = MultiHeadAttention(input_dim, num_heads)
        self.layer_norm2 = nn.LayerNorm(input_dim)
        hidden_feat = int(input_dim * mlp_ratio)
        self.mlp = MLP(input_dim, hidden_feat, input_dim)

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a residual add.

        Args:
            x: Tensor whose last dimension equals ``input_dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.attention(self.layer_norm1(x))
        x = x + self.mlp(self.layer_norm2(x))
        return x

In [None]:
class VisionTransformer(nn.Module):
    """Small Vision Transformer classifier.

    Pipeline: a strided convolution turns the image into patch embeddings,
    a learnable class token is prepended, learnable positional encodings are
    added, the sequence is layer-normalised and passed through ``layers``
    ``TransformerLayer`` blocks, then averaged over tokens and mapped to class
    logits.

    Args:
        layers: Number of transformer layers.
        image_size: Height/width of the (square) input images.
        patch_size: Side of each square patch; used as both kernel size and
            stride of the patch-embedding convolution.
        in_channels: Number of image channels.
        embed_dim: Token feature size.
        num_heads: Attention heads per transformer layer.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        num_classes: Number of output logits.

    Raises:
        ValueError: If ``image_size`` is smaller than ``patch_size``.
    """

    def __init__(self, layers, image_size=32, patch_size=16, in_channels=3,
                 embed_dim=64, num_heads=8, mlp_ratio=4, num_classes=10):
        super().__init__()
        if image_size < patch_size:
            raise ValueError(
                f"image_size ({image_size}) must be at least patch_size ({patch_size})"
            )
        # With kernel == stride the conv yields floor(size / patch) patches per side.
        self.num_patches = (image_size // patch_size) ** 2

        self.patch_embed = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        # One learnable token, laid out as (1, 1, embed_dim) so it can be
        # expanded over the batch and concatenated along the token axis.
        self.class_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.transformer = nn.ModuleList(
            [
                TransformerLayer(input_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio)
                for _ in range(layers)
            ]
        )
        # One positional vector per token (patches + class token).
        self.pos_enc = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.linear = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, image_size, image_size)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.

        Raises:
            ValueError: If the input yields a different number of patches than
                the positional encoding was built for.
        """
        x = self.patch_embed(x)              # (B, C, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)     # (B, N, C): one row per patch
        if x.size(1) != self.num_patches:
            raise ValueError(
                f"expected {self.num_patches} patches per image, got {x.size(1)}; "
                "check the input size against image_size/patch_size"
            )

        # Prepend the class token instead of adding it to every patch.
        cls = self.class_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls, x), dim=1)       # (B, N + 1, C)

        x = x + self.pos_enc
        x = self.layer_norm(x)
        for block in self.transformer:
            x = block(x)

        x = x.mean(dim=1)                    # average over tokens -> (B, C)
        return self.linear(x)
        

In [None]:
# Build a 12-layer Vision Transformer and move it to the selected device.
# nn.Module.to() returns the module itself, so the call can be chained.
model = VisionTransformer(layers=12).to(device)
model  # echo the module as the cell output

In [None]:
# Training setup: Adam over all model parameters at a learning rate of 1e-3,
# optimising a cross-entropy loss.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()


In [None]:


model.train()  # Set the model to training mode

# Train for 10 epochs, reporting the mean loss of every 100 batches.
for epoch in range(10):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        # Tensor.to() returns a new tensor rather than moving in place, so the
        # result must be re-bound; otherwise the batch stays on the CPU while
        # the model may live on the GPU.
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass and loss
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backpropagation and parameter update
        loss.backward()
        optimizer.step()

        # Running statistics (per-batch shape prints removed: they flooded the output)
        running_loss += loss.item()
        if i % 100 == 99:
            print(f"Epoch: {epoch+1}, Batch: {i+1}, Loss: {running_loss/100}")
            running_loss = 0.0
