In [4]:
import torch 
import torch.nn as nn  
import matplotlib.pyplot as plt 
import torchvision
import numpy as np 
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [5]:
learning_rate = 0.001
weight_decay = 0.0001
num_epochs = 100
image_size = 72  # We'll resize input images to this size
patch_size = 6  # Size of the patches to be extract from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [2048, 1024]  # Size of the dense layers of the final classifier
batch_size = 64

In [2]:
num_classes = 100
input_shape = (32,32,3)
import torchvision.transforms as transforms

In [3]:
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

# number of subprocesses to use for data loading
num_workers = 0
# how many samples per batch to load
batch_size = 5
# percentage of training set to use as validation
valid_size = 0.2

# convert data to a normalized torch.FloatTensor
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.Resize(image_size)
    ])

# choose the training and test datasets
train_data = datasets.CIFAR10('data', train=True,
                              download=True, transform=transform)
test_data = datasets.CIFAR10('data', train=False,
                             download=True, transform=transform)

# obtain training indices that will be used for validation
num_train = len(train_data)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]

# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# prepare data loaders (combine dataset and sampler)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
    sampler=train_sampler, num_workers=num_workers)
valid_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, 
    sampler=valid_sampler, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, 
    num_workers=num_workers)

# specify the image classes
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

Files already downloaded and verified
Files already downloaded and verified


In [10]:
class MLPS(nn.Module): 
    def __init__(self,hidden_units,drop_prob=0.0,hidden_size1=64):
        super(MLPS,self).__init__()
        self.layers = []
        hidden_units.insert(0,hidden_size1)
        for i in range(0,len(hidden_units)-1): 
            self.layers.append(nn.Linear(hidden_units[i],hidden_units[i+1]))
            self.layers.append(nn.Dropout(drop_prob))
        
        self.layers = nn.Sequential(*self.layers)
    
    def forward(self,x): 
        for layer in self.layers:
            x = layer(x)
        return x   


In [11]:
class TransformerBlock(nn.Module):
    def __init__(self,hidden_units,num_heads,projection_dim,dropout=0.0):
        super(TransformerBlock,self).__init__()
        self.hidden_units = hidden_units
        self.attention = nn.MultiheadAttention(projection_dim,num_heads,dropout=dropout)
        self.norm = nn.LayerNorm(projection_dim,eps=1e-6)
        self.mlp = MLPS(hidden_units)

    def forward(self,encoded_patches):
        x1 = encoded_patches
        #add norm layer here
        attention_out, out_weights = self.attention(x1,x1,x1)
        x2 = x1 + x1 #residual concatenation 
        #add norm layer here
        x3 = self.mlp(x2)
        encoded = x3 + x2
        return encoded
        
  

In [12]:
class VIT(nn.Module):
   def __init__(self,transformer=None,image_size=image_size,patch_size=patch_size,num_classes=num_classes,projection_dim=projection_dim,num_patches=num_patches,
   hidden_units=transformer_units, num_heads = num_heads,mlp_head_units=mlp_head_units, drop_prob=0.0):
      super(VIT,self).__init__()
      assert image_size % patch_size == 0, "image size must be dividible by patch size" 
      self.num_patches =  num_patches
      self.patch_dim = 3 * patch_size**2
   
      self.flatten_patches = Rearrange('b c (h px1) (w px2) -> b (h w) (px1 px2 c)', px1 = patch_size, px2 = patch_size)
      self.patch_emedding = nn.Linear(self.patch_dim,projection_dim)

      self.position_embedding = nn.Parameter(torch.randn(1, num_patches + 1, projection_dim))   #pose embedding 
      self.class_token = nn.Parameter(torch.randn(1, 1, projection_dim))   #class embedding 
     
      self.transformer = TransformerBlock(hidden_units,num_heads,projection_dim,drop_prob)
      self.pose_embedding = nn.Embedding(num_patches+1,projection_dim)
      self.projection = nn.Linear(108,projection_dim)
      #output
      self.mlp_head = MLPS(mlp_head_units,drop_prob=0.5,hidden_size1=9216)

      self.to_latent = nn.Linear(mlp_head_units[-1],num_classes)
      self.flatten = nn.Flatten()
      self.drop = nn.Dropout(drop_prob)
      self.softmax = nn.Softmax(dim=0)

   
   def forward(self,img):
      #"pre process"
      """
      x = self.flatten_patches(x)
      x = self.patch_emedding(x)
      """
      patch = self.flatten_patches(img)
      #patch encoding 
      positions = torch.arange(0,self.num_patches)
      encoded1 = self.pose_embedding(positions)
      encoded2 = self.projection(patch)
      encoded = encoded1 + encoded2
      block = self.transformer(encoded)
      block_norm = block #add norm here
      representation = self.drop(self.flatten(block_norm))
      features = self.mlp_head(representation)
      logits = self.to_latent(features)
      return logits 
   

In [14]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)


In [16]:
device = "cuda" if torch.cuda.is_available() else "cpu" 

if device == "cuda":
    model = model.to(device)

val_loss_min = float("Inf")
for i in range(0,num_epochs):
    train_loss,val_loss = 0.,0. 
    model.train()

    for img, label in train_loader:
        if device == "cuda":
            img, label = img.to(device), label.to(device)
        optimizer.zero_grad()
        out = model(img)
        loss = criterion(out,label)
        loss.backward()
        optimizer.step() 
        train_loss += loss.item() * batch_size
        print(train_loss)
    #validation
    
    model.eval()
    for img, label in val_loader:
        if device == "cuda":
            img, label = img.to(device), label.to(device)
        
        out = model(img)
        loss = criterion(out,label)
        val_loss += loss.item() * batch_size
    
    val_loss = val_loss/len(val_loader.dataset)
    train_loss = train_loss/len(train_loader.dataset)

    # print training/validation statistics 
    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch, train_loss, val_loss))
    
    # save model if validation loss has decreased
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
        val_loss_min,
        val_loss))
        torch.save(model.state_dict(), 'model_cifar.pt')
        val_loss_min = val_loss
    



25.097453594207764
378.1101107597351
833.0985903739929
1527.257068157196
2234.0160250663757
3477.58642911911
4844.256808757782
5661.485583782196
7157.842547893524
8843.365466594696
9598.342578411102
11015.31965970993
12370.721118450165
12923.319079875946
14260.325152873993
14727.783954143524
15695.101459026337
16672.40446805954
18322.389056682587
19063.379123210907
19501.151111125946
19724.660193920135
20229.36669111252
20588.683936595917
21609.88701581955
21989.82338666916
22639.956290721893
23181.625678539276
23448.797652721405
24064.12155866623
24509.9236369133
24783.65441083908
24966.737945079803
26076.55277967453
26622.13341474533
27141.626212596893
27797.334525585175
28760.880210399628
29136.16329908371
29620.96054792404
29910.05966901779
30261.666839122772
30581.55081510544
30921.21092557907
31588.05144071579
32522.729785442352
33655.37001371384
34073.66751432419
34238.18985700607
34839.27078962326
35069.88425016403
35594.79922056198
36022.68233060837
36164.25070524216
36747.819

KeyboardInterrupt: 

In [64]:
target = torch.tensor([1, 3, 5, 9, 4])
criterion = nn.CrossEntropyLoss()

out = torch.Tensor([[ 0.0336,  0.2922, -0.2848,  1.5458, -0.0692, -0.0332, -0.4918,  1.7504,
         -1.0620, -0.6737],
        [ 0.0451,  0.0415, -0.0913,  1.0814, -0.6469,  1.1928, -1.3046, -0.3185,
         -1.6709, -0.9702],
        [-0.6552, -0.5635,  1.4852,  0.8358, -0.5598,  1.9645, -0.7523,  0.5076,
          0.6408,  0.7347],
        [-0.8069, -0.3164,  0.6914, -0.1573, -1.8538,  0.5686, -0.2744,  1.5524,
         -0.4634,  1.1055],
        [ 0.2714, -1.4807,  1.1702, -0.3368, -1.1552, -0.0126, -0.3035,  0.1973,
         -0.0620,  0.5993]])

loss = criterion(out,target)