In [1]:
import matplotlib.pyplot as plt 
import torch
import torch.nn as nn 
import torch.optim as optim
from torchvision import transforms
import torchvision.models as models 
import torchvision
from torchinfo import summary

In [2]:
# Pick the compute device: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

### Get the pretrained ViT model with weights and biases.

In [46]:
# Build ViT-B/16 and initialise it from torchvision's IMAGENET1K_V1 checkpoint.
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)

In [47]:
# Per-layer output shapes and parameter counts for a batch of 32 RGB images
# of size 224x224.
summary(model, input_size=(32, 3, 224, 224))

Layer (type:depth-idx)                        Output Shape              Param #
VisionTransformer                             [32, 1000]                768
├─Conv2d: 1-1                                 [32, 768, 14, 14]         590,592
├─Encoder: 1-2                                [32, 197, 768]            151,296
│    └─Dropout: 2-1                           [32, 197, 768]            --
│    └─Sequential: 2-2                        [32, 197, 768]            --
│    │    └─EncoderBlock: 3-1                 [32, 197, 768]            7,087,872
│    │    └─EncoderBlock: 3-2                 [32, 197, 768]            7,087,872
│    │    └─EncoderBlock: 3-3                 [32, 197, 768]            7,087,872
│    │    └─EncoderBlock: 3-4                 [32, 197, 768]            7,087,872
│    │    └─EncoderBlock: 3-5                 [32, 197, 768]            7,087,872
│    │    └─EncoderBlock: 3-6                 [32, 197, 768]            7,087,872
│    │    └─EncoderBlock: 3-7             

In [48]:
# Display the module tree of the pretrained model.
model

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

#### ViT has 3 main parts:
   1. conv_proj : the patch-embedding layer.
   2. encoder : the Transformer encoder; it contains 12 Transformer blocks stacked on top of each other.
   3. heads : the classifier block; inside `heads` there is a sequential fully connected block whose `head` layer is the main classifier.

In [7]:
# Patch-embedding layer: a convolution with a 16x16 kernel and stride 16.
model.conv_proj

Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))

In [8]:
# Get the encoder part: a dropout layer followed by the stacked encoder blocks.
model.encoder

Encoder(
  (dropout): Dropout(p=0.0, inplace=False)
  (layers): Sequential(
    (encoder_layer_0): EncoderBlock(
      (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (self_attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (dropout): Dropout(p=0.0, inplace=False)
      (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLPBlock(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.0, inplace=False)
        (3): Linear(in_features=3072, out_features=768, bias=True)
        (4): Dropout(p=0.0, inplace=False)
      )
    )
    (encoder_layer_1): EncoderBlock(
      (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (self_attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
    

In [10]:
# Get the final classification layer (768 input features -> 1000 outputs).
model.heads.head

Linear(in_features=768, out_features=1000, bias=True)

### Make the dataset and dataloader

In [11]:
# Number of worker processes passed to every DataLoader.
NUM_WORKER = 4

# Per-channel statistics used to normalise the image tensors.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Preprocessing applied to every image, in order:
# resize -> central 224 crop -> tensor conversion -> normalisation.
preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
preprocess = transforms.Compose(preprocess_steps)

In [12]:
# Root folders of the three dataset splits (relative to the notebook's
# working directory).
train_data_path, test_data_path, validation_data_path = (
    "Data/Flowers/train",
    "Data/Flowers/test",
    "Data/Flowers/valid",
)

In [14]:
from torchvision import datasets
from torch.utils.data import DataLoader

# Mini-batch size used by every DataLoader.
BATCH_SIZE = 32

# Build one ImageFolder per split (train, test, validation — in that order);
# all three share the same preprocessing pipeline.
train_dataset, test_dataset, valid_dataset = (
    datasets.ImageFolder(root=split_root, transform=preprocess)
    for split_root in (train_data_path, test_data_path, validation_data_path)
)

In [15]:
# Class names exposed by the training ImageFolder dataset.
class_names = train_dataset.classes

In [16]:
# Show the class names.
class_names

['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']

In [17]:
# Data loaders for the three splits. They share every option except
# shuffling: only the training loader shuffles its samples.
common_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKER,
    pin_memory=True,
)

train_loader = DataLoader(train_dataset, shuffle=True, **common_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **common_loader_kwargs)
val_loader = DataLoader(valid_dataset, shuffle=False, **common_loader_kwargs)

In [49]:
# Display the module tree of the pretrained model again before wrapping it.
model

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

### Build the custom ViT for our flower dataset

In [50]:
class FlowerViT(nn.Module):
    """Pretrained ViT backbone with a new, trainable classification head.

    Every parameter of ``base_model`` outside its ``heads`` sub-module is
    frozen, and ``base_model.heads.head`` is replaced by a freshly initialised
    ``nn.Linear`` that outputs ``num_of_classes`` logits.

    Note:
        ``base_model`` is modified in place: its parameters are frozen and its
        head is swapped. The object passed in is therefore no longer the
        original pretrained classifier after construction.

    Args:
        base_model: Model exposing ``heads.head`` (a layer with an
            ``in_features`` attribute), e.g. a torchvision VisionTransformer.
        num_of_classes: Number of output classes for the new head.
    """

    def __init__(self, base_model: nn.Module, num_of_classes: int):
        super().__init__()

        self.base_model = base_model

        # Freeze the whole backbone. Walking named_parameters() (instead of
        # only conv_proj and encoder) also covers parameters registered
        # directly on the base model — e.g. the class token of torchvision's
        # VisionTransformer — which would otherwise silently stay trainable.
        for name, param in self.base_model.named_parameters():
            if not name.startswith("heads."):
                param.requires_grad = False

        # Replace the classifier (head); the new layer is trainable by default.
        in_features = self.base_model.heads.head.in_features
        self.base_model.heads.head = nn.Linear(in_features, num_of_classes)

    def forward(self, x):
        """Return the logits produced by the wrapped model for input ``x``."""
        return self.base_model(x)

In [51]:
# Classifier head of the pretrained model before FlowerViT replaces it
# (at this point it is still the 1000-output layer).
model.heads.head

Linear(in_features=768, out_features=1000, bias=True)

In [52]:
# Class names; their count sets the size of the new classifier head.
class_names

['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']

In [53]:
# Wrap the pretrained model with a head sized for the flower classes, then
# move the wrapped model to the selected device.
vit = FlowerViT(
    base_model=model,
    num_of_classes=len(class_names),
)
vit = vit.to(device)

In [54]:
# Cross-entropy between the model's logits and the integer class labels.
criterion = nn.CrossEntropyLoss()

# Give the optimizer only the tensors that are actually trained. The frozen
# backbone parameters never receive gradients, so passing them would only
# make the optimizer iterate over tensors it can never update.
trainable_params = [p for p in vit.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=3e-4,
    weight_decay=1e-4
)

In [56]:
# Fine-tuning loop: each epoch runs one pass over the training loader followed
# by one evaluation pass over the validation loader, then prints the
# sample-weighted mean loss and the accuracy of both passes.
EPOCHS = 5
for epoch in range(EPOCHS):
    # ---- training pass ----
    vit.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    for images, labels in train_loader:
        # The loaders use pin_memory=True, so the host-to-device copies can be
        # issued asynchronously instead of blocking the host on each batch.
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # forward
        outputs = vit(images)
        loss = criterion(outputs, labels)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # metrics: the batch-mean loss is weighted by the batch size so that
        # dividing by the sample count below yields a per-sample average.
        train_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        train_correct += (preds == labels).sum().item()
        train_total += labels.size(0)

    train_loss /= train_total
    train_acc = train_correct / train_total

    # ---- validation pass (no gradient tracking, no parameter updates) ----
    vit.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = vit(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_loss /= val_total
    val_acc = val_correct / val_total

    print(
        f"Epoch [{epoch+1}/{EPOCHS}] | "
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
    )

Epoch [1/5] | Train Loss: 0.4848, Train Acc: 0.8694 | Val Loss: 0.4260, Val Acc: 0.8853
Epoch [2/5] | Train Loss: 0.3688, Train Acc: 0.9022 | Val Loss: 0.3588, Val Acc: 0.8957
Epoch [3/5] | Train Loss: 0.3088, Train Acc: 0.9174 | Val Loss: 0.3244, Val Acc: 0.9061
Epoch [4/5] | Train Loss: 0.2711, Train Acc: 0.9297 | Val Loss: 0.2993, Val Acc: 0.9131
Epoch [5/5] | Train Loss: 0.2440, Train Acc: 0.9357 | Val Loss: 0.2832, Val Acc: 0.9119


In [57]:
# Final evaluation on the held-out test split: count correct top-1
# predictions with the model in eval mode and gradient tracking disabled.
vit.eval()
correct , total = 0 , 0
with torch.no_grad():
    for images , labels in test_loader:
        # The test loader uses pin_memory=True, so the host-to-device copies
        # can be issued asynchronously.
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        outputs = vit(images)
        # Predicted class = index of the largest logit in each row.
        _ , pred = torch.max(outputs , 1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)

print(f"Test acc: {correct / total}")

Test acc: 0.9236111111111112
