## ViT - Vision Transformer

* $\textbf{Vision Transformer (ViT)}$ model pre-trained on ImageNet-21k (14 million images, 21,843 classes) at resolution 224x224, and fine-tuned on ImageNet 2012 (1 million images, 1,000 classes) at resolution 224x224.

* Fine-tuned on a $\textbf{custom dataset}$

* To do: train the model from scratch (adjust the model config: patch size, dropout, ...)

### Config

In [1]:
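EPOCHS = 10
BATCH_SIZE = 1
LEARNING_RATE = 1e-5
EVAL_BATCH = 1

PRETRAINED = True  # True: start from google/vit-base-patch16-224, False: random weights

AUGMENTATION = False  # data augmentation decreased the overall accuracy in our runs
AUGS = 1  # number of augmented copies appended to the training set

GENDATA = True  # True: train on ./data/Stable_Diffusion_data (with synthetic images)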

### Preprocessing

Uncomment the cell below to resize the images to 224x224, the input size of $\textbf{google/vit-base-patch16-224}$.

The cell processes one class directory at a time, so set `path` and `save_path` for each class and run it once per directory.
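As an alternative to editing the paths and re-running the commented cell below for every class, the resize can be done for all class folders in one pass. This is a minimal sketch, assuming the `<root>/<class>/<image>.jpeg` layout implied by the commented cell; `resize_tree` is a hypothetical helper name, and the `convert("RGB")` call is an added safeguard the original cell does not have.

```python
import os
from PIL import Image


def resize_tree(src_root, dst_root, size=(224, 224), ext=".jpeg", quality=90):
    """Resize every `ext` image in src_root/<class>/ into dst_root/<class>/.

    Returns the number of images written.
    """
    count = 0
    for class_name in sorted(os.listdir(src_root)):
        src_dir = os.path.join(src_root, class_name)
        if not os.path.isdir(src_dir):
            continue  # skip stray files next to the class folders
        dst_dir = os.path.join(dst_root, class_name)
        os.makedirs(dst_dir, exist_ok=True)
        for fname in os.listdir(src_dir):
            if not fname.lower().endswith(ext):
                continue
            with Image.open(os.path.join(src_dir, fname)) as im:
                # LANCZOS is the filter that the removed Image.ANTIALIAS alias pointed to
                resized = im.convert("RGB").resize(size, Image.LANCZOS)
                resized.save(os.path.join(dst_dir, fname), "JPEG", quality=quality)
            count += 1
    return count


# Usage (paths taken from the commented cell; adjust to your layout):
# resize_tree("./train", "./train_resized")
```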

In [2]:
# import os
# from PIL import Image

# path = "./train/sit/"
# save_path = "./train_resized/sit/"  # specify a directory for each class: sit, lie, run, walk_stand
# os.makedirs(save_path, exist_ok=True)

# for f in os.listdir(path):
#     if f.endswith('jpeg'):
#         im = Image.open(os.path.join(path, f))
#         # Image.ANTIALIAS was removed in Pillow 10; it was an alias for LANCZOS
#         imResize = im.resize((224, 224), Image.LANCZOS)
#         imResize.save(os.path.join(save_path, f), 'JPEG', quality=90)

### Dataset

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
from torch.autograd import Variable
from torchvision.transforms import ToTensor, v2
from tqdm.auto import tqdm  # the later `import tqdm` rebound this name to the module; removed
from transformers import ViTConfig, ViTModel, ViTFeatureExtractor
from transformers.modeling_outputs import SequenceClassifierOutput

# No eval data set to evaluate the model during training, due to data scarcity.
# The data is not on GitHub -> change the data paths to your local copy.
if GENDATA:
    train_dir = './data/Stable_Diffusion_data/train_resized'
else:
    train_dir = './data/train_resized'
train_ds = torchvision.datasets.ImageFolder(train_dir, transform=ToTensor())
test_ds = torchvision.datasets.ImageFolder('./data/test_resized', transform=ToTensor())

print('Number of classes: ', len(train_ds.classes))
print('Training samples: ', len(train_ds))
print('Test samples: ', len(test_ds))

# samples per class (ImageFolder.targets is a list of class indices)
print('\nSamples per class in training set:')
for i, name in enumerate(train_ds.classes):
    print('Class {}: '.format(name), train_ds.targets.count(i))

print('\nSamples per class in test set:')
for i, name in enumerate(test_ds.classes):
    print('Class {}: '.format(name), test_ds.targets.count(i))



Number of classes:  4
Training samples:  1180
Test samples:  87

Samples per class in training set:
Class lie:  275
Class run:  213
Class sit:  161
Class walk_stand:  531

Samples per class in test set:
Class lie:  10
Class run:  8
Class sit:  8
Class walk_stand:  61


### Data augmentation

* To do: cropping, flipping, ...

In [4]:
# data augmentation for training data

if AUGMENTATION:

    augmentations = AUGS  # number of augmented copies to append

    for i in range(augmentations):

        # Note: augmented copies are always built from ./data/train_resized, regardless of GENDATA
        augmented_train_data = torchvision.datasets.ImageFolder('./data/train_resized', transform=v2.Compose([
            v2.ToImage(),  # PIL image -> tensor image (uint8)
            v2.RandomResizedCrop(size=(224, 224), antialias=True),
            v2.RandomRotation(degrees=15),
            #v2.RandomVerticalFlip(),
            v2.RandomHorizontalFlip(),
            v2.ToDtype(torch.float32, scale=True),  # replaces the deprecated v2.ToTensor(): float in [0, 1]
            #v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]))

        train_ds = data.ConcatDataset([train_ds, augmented_train_data])
        print('We added augmented data to the training set. The new length of the training set is: ', len(train_ds))
else:
    print('No data augmentation applied.')

# no data augmentation for test data

No data augmentation applied.


### Custom model

* To do: implement ViT from Hugging Face using `model.config`

In [5]:
# Initializing a ViT vit-base-patch16-224 style configuration
configuration = ViTConfig(patch_size=16)

# Initializing a model (with random weights) from the vit-base-patch16-224 style configuration
test_model = ViTModel(configuration)

# Accessing the model configuration
configuration = test_model.config

### Adjust $\textbf{google/vit-base-patch16-224}$

We add a dropout layer and a linear classification layer on top of the [CLS] token output to predict the 4 classes.
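For reference, the length of the sequence the head sits on: `google/vit-base-patch16-224` splits a 224x224 image into non-overlapping 16x16 patches and prepends one [CLS] token, whose final hidden state (width 768 for ViT-Base) is what the dropout and linear layer consume:

$$N_{\text{patches}} = \left(\frac{224}{16}\right)^2 = 14^2 = 196, \qquad L = N_{\text{patches}} + 1 = 197$$

so `last_hidden_state` has shape $(B, 197, 768)$, `last_hidden_state[:, 0]` has shape $(B, 768)$, and the classifier maps $\mathbb{R}^{768} \to \mathbb{R}^{4}$.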

In [6]:
class ViTForImageClassification(nn.Module):
    def __init__(self, num_labels=4):
        super().__init__()
        if PRETRAINED:
            self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
            # optionally freeze the backbone weights:
            # for param in self.vit.parameters():
            #     param.requires_grad = False
            print('Model initialized with pre-trained weights.')
        else:
            self.vit = ViTModel(configuration)
            print('\nModel initialized with random weights.')
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels=None):
        outputs = self.vit(pixel_values=pixel_values)
        # Classify from the final hidden state of the [CLS] token (position 0)
        output = self.dropout(outputs.last_hidden_state[:, 0])
        logits = self.classifier(output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        # Return the loss tensor (not loss.item()) so it can still be back-propagated
        return logits, loss
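To see what the head above receives without downloading the checkpoint, the sketch below builds a deliberately tiny, randomly initialised ViT (all sizes are hypothetical, chosen only so it runs fast) and applies the same dropout + linear head to `last_hidden_state[:, 0]`. With 32x32 inputs and 8x8 patches there are 16 patches plus the [CLS] token.

```python
import torch
import torch.nn as nn
from transformers import ViTConfig, ViTModel

# Tiny random-weight config (hypothetical sizes, not vit-base-patch16-224)
tiny_config = ViTConfig(image_size=32, patch_size=8, hidden_size=32,
                        num_hidden_layers=2, num_attention_heads=2,
                        intermediate_size=64)
backbone = ViTModel(tiny_config, add_pooling_layer=False)
head = nn.Sequential(nn.Dropout(0.1), nn.Linear(tiny_config.hidden_size, 4))

pixel_values = torch.randn(2, 3, 32, 32)          # batch of 2 RGB images
with torch.no_grad():
    hidden = backbone(pixel_values=pixel_values).last_hidden_state
    logits = head(hidden[:, 0])                   # classify from the [CLS] token

print(hidden.shape)  # torch.Size([2, 17, 32]): 16 patches + [CLS]
print(logits.shape)  # torch.Size([2, 4])
```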

In [7]:
# Define Model
# 4 classes; len(train_ds.classes) would fail once augmentation turns train_ds into a ConcatDataset
model = ViTForImageClassification(num_labels=4)

# ViTImageProcessor supersedes the deprecated ViTFeatureExtractor
from transformers import ViTImageProcessor
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

loss_func = nn.CrossEntropyLoss()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

Some weights of the model checkpoint at google/vit-base-patch16-224 were not used when initializing ViTModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized with pre-trained weights.


In [8]:
print("Number of train samples: ", len(train_ds))
print("Number of test samples: ", len(test_ds))
#print("Detected Classes are: ", train_ds.class_to_idx)

train_loader = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
# The test set does not need shuffling
test_loader = data.DataLoader(test_ds, batch_size=EVAL_BATCH, shuffle=False, num_workers=4)

# Train the model
model.train()
for epoch in range(EPOCHS):
    for step, (x, y) in enumerate(train_loader):
        # Resize + normalize each image of the batch with the ViT processor,
        # getting one (B, 3, 224, 224) tensor back
        x = feature_extractor(list(x), return_tensors='pt')['pixel_values']
        # Send to GPU if available
        x, y = x.to(device), y.to(device)
        # Feed through model; the loss is computed here rather than inside the model
        logits, _ = model(x, None)
        loss = loss_func(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Number of train samples:  1180
Number of test samples:  87


In [9]:
num_classes = len(test_ds.classes)
# Confusion matrix: rows = true class, columns = predicted class
confusion_matrix = torch.zeros(num_classes, num_classes, dtype=torch.long)

model.eval()
with torch.no_grad():
    for inputs, targets in test_loader:
        # Same preprocessing as in training
        inputs = feature_extractor(list(inputs), return_tensors='pt')['pixel_values'].to(device)
        targets = targets.to(device)

        # Generate prediction; predicted class = argmax over the logits
        logits, _ = model(inputs, None)
        predicted = logits.argmax(dim=-1)

        for t, p in zip(targets.view(-1), predicted.view(-1)):
            confusion_matrix[t.long(), p.long()] += 1

correct_per_class = confusion_matrix.diag()     # correctly classified samples per class
total_per_class = confusion_matrix.sum(dim=1)   # test samples per class
num_correct = correct_per_class.sum().item()

print('Accuracy: ', num_correct / confusion_matrix.sum().item() * 100)
print('\nTotal over all classes: ', num_correct, ' out of ', len(test_ds))
for i, name in enumerate(test_ds.classes):
    print('Class {}: '.format(name), correct_per_class[i].item(), ' out of ', total_per_class[i].item())
# print('Confusion matrix: ')
# print(confusion_matrix)


Accuracy:  93.10344827586206

Total over all classes:  81  out of  87
Class lie:  10  out of  10
Class run:  6  out of  8
Class sit:  7  out of  8
Class walk_stand:  58  out of  61
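The per-class counts printed above (for example `run: 6 out of 8`) are per-class recall: correct predictions divided by the number of true samples of that class. A full confusion matrix also gives precision, which matters with a test set this imbalanced (61 of 87 images are `walk_stand`). A minimal NumPy sketch, assuming `y_true` and `y_pred` are lists of class indices (in the notebook these would be the targets and `logits.argmax(dim=-1)` collected over `test_loader`); the function names are hypothetical.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes):
    """Rows = true class, columns = predicted class."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm


def per_class_metrics(cm):
    """Return per-class precision and recall from a confusion matrix."""
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp   # predicted as class c, but true class differs
    fn = cm.sum(axis=1) - tp   # true class c, predicted as something else
    # Avoid division by zero for classes that never occur or are never predicted
    precision = np.divide(tp, tp + fp, out=np.zeros(len(tp)), where=(tp + fp) > 0)
    recall = np.divide(tp, tp + fn, out=np.zeros(len(tp)), where=(tp + fn) > 0)
    return precision, recall


# Toy example with 4 classes (not the notebook's data)
y_true = [0, 0, 1, 1, 2, 3]
y_pred = [0, 1, 1, 1, 2, 3]
cm = confusion_matrix(y_true, y_pred, num_classes=4)
precision, recall = per_class_metrics(cm)
print(cm)
print("precision:", precision)
print("recall:   ", recall)
```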


## Model from scratch

#### no data augmentation, no synthetic data

* Model 1:  71.26 
* Model 2:  71.26
* Model 3:  66.67
* Model 4:  70.11
* Model 5:  67.82

Average accuracy: 69.42

#### with data augmentation, no synthetic data

* Model 1:  70.11 
* Model 2:  70.11
* Model 3:  70.11

Average accuracy: 70.11

Per-class results example (correct test predictions per class):

* Total over all classes: 62 out of 87
* Class lie: 1 out of 10
* Class run: 0 out of 8
* Class sit: 1 out of 8
* Class walk_stand: 60 out of 61
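To put the from-scratch numbers in context: the test set is dominated by `walk_stand` (61 of 87 images), so a trivial classifier that always predicts `walk_stand` already scores about 70.11%. The from-scratch accuracies sit close to this baseline, and the example above (60 of 61 `walk_stand` correct, but only 2 of the 26 other images) is consistent with a model that mostly predicts the majority class. The counts below are copied from the dataset summary printed earlier.

```python
# Test-set class counts, copied from the dataset summary output
test_counts = {"lie": 10, "run": 8, "sit": 8, "walk_stand": 61}
n_test = sum(test_counts.values())  # 87

# Accuracy of a classifier that always predicts the most frequent class
majority_class = max(test_counts, key=test_counts.get)
baseline_acc = 100 * test_counts[majority_class] / n_test
print(majority_class, round(baseline_acc, 2))  # walk_stand 70.11

# Accuracy of the per-class example: 1 + 0 + 1 + 60 correct predictions
example_acc = 100 * (1 + 0 + 1 + 60) / n_test
print(round(example_acc, 2))  # 71.26
```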

## Model pretrained on ImageNet

#### NO data augmentation and NO synthetic data  (10 epochs)

* Model 1:  93.10 
* Model 2:  91.95
* Model 3:  94.25
* Model 4:  91.95
* Model 5:  91.95
* Model 6:  93.10
* Model 7:  93.10
* Model 8:  91.95
* Model 9:  91.95
* Model 10: 93.10

Average accuracy: 92.64
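With only 87 test images, each image is worth 100/87 ≈ 1.15 percentage points, so the spread between 91.95, 93.10 and 94.25 corresponds to one or two images. Reporting the run-to-run standard deviation next to the mean makes that visible; a small sketch using the 10 values listed above:

```python
import statistics

# Test accuracies (%) of the 10 pretrained runs listed above
pretrained_runs = [93.10, 91.95, 94.25, 91.95, 91.95, 93.10, 93.10, 91.95, 91.95, 93.10]

mean_acc = statistics.mean(pretrained_runs)
std_acc = statistics.stdev(pretrained_runs)  # sample standard deviation across runs
per_image = 100 / 87                          # accuracy points per test image

print(f"{mean_acc:.2f} +/- {std_acc:.2f}")  # 92.64 +/- 0.80
print(round(per_image, 2))                  # 1.15
```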


#### WITH data augmentation and NO synthetic data

* Model 1:  86.21 
* Model 2:  87.36
* Model 3:  89.66
* Model 4:  89.66

Average accuracy: 88.22

Per-class results example (correct test predictions per class):

* Total over all classes: 78 out of 87
* Class lie: 7 out of 10
* Class run: 6 out of 8
* Class sit: 6 out of 8
* Class walk_stand: 59 out of 61

#### NO data augmentation and WITH synthetic data

* Model 1:  94.25 
* Model 2:  95.40
* Model 3:  94.25 

Per-class results example (correct test predictions per class):

* Total over all classes: 82 out of 87
* Class lie: 10 out of 10
* Class run: 6 out of 8
* Class sit: 7 out of 8
* Class walk_stand: 59 out of 61

### Save model

We save and provide the best model, with a test accuracy of 94.25%.

In [10]:
# save model
# torch.save(model, './models/best_model.pt')

### Load model

In [11]:
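# MODEL_PATH = './models/best_model.pt'
# # The file holds the whole pickled model (ViTForImageClassification must be defined),
# # so PyTorch >= 2.6, where weights_only defaults to True, needs weights_only=False
# model = torch.load(MODEL_PATH, weights_only=False)
# model.eval()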
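Pickling the whole model ties the file to the exact class definition and needs `weights_only=False` to load. Saving only the `state_dict` is the approach recommended in the PyTorch docs; a minimal sketch, where `save_weights`/`load_weights` and the file name in the usage comment are hypothetical:

```python
import torch
import torch.nn as nn


def save_weights(model: nn.Module, path: str) -> None:
    """Store only the parameters and buffers (state_dict), not the pickled class."""
    torch.save(model.state_dict(), path)


def load_weights(model: nn.Module, path: str, device: str = "cpu") -> nn.Module:
    """Load a state_dict saved by save_weights into an already constructed model."""
    state = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.eval()  # switch dropout off for inference
    return model


# Usage in this notebook (hypothetical file name):
# save_weights(model, './models/best_model_state_dict.pt')
# model = load_weights(ViTForImageClassification(num_labels=4), './models/best_model_state_dict.pt')
```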