Import Libraries

In [1]:
# !pip uninstall transformers -y
!pip install transformers

Collecting transformers
  Downloading transformers-4.11.3-py3-none-any.whl (2.9 MB)
[K     |████████████████████████████████| 2.9 MB 4.2 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 38.4 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 48.3 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 32.6 MB/s 
Collecting huggingface-hub>=0.0.17
  Downloading huggingface_hub-0.0.19-py3-none-any.whl (56 kB)
[K     |████████████████████████████████| 56 kB 4.7 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Attempti

In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sklearn
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import ToTensor, Compose, Resize, Normalize
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
from transformers import BeitModel, BeitConfig, BeitFeatureExtractor, BeitForImageClassification
from torch.utils.data import random_split
from tqdm import tqdm
import sys
import torch.nn.functional as F
import time

In [3]:
# Display the installed transformers version so the results below can be tied
# to a specific library release.
import transformers
transformers.__version__

'4.11.3'

Download CIFAR100 dataset

In [4]:
# The feature extractor of the pretrained checkpoint supplies the per-channel
# mean/std used for normalisation below.
feature_extractor = BeitFeatureExtractor.from_pretrained(
    'microsoft/beit-base-patch16-224-pt22k-ft22k'
)

# Preprocessing pipeline: PIL image -> tensor -> 224x224 -> normalised tensor.
normalize = Normalize(feature_extractor.image_mean, feature_extractor.image_std)
transform = Compose([ToTensor(), Resize([224, 224]), normalize])

# Only the first call requests a download; the test split is read from the
# same `data/` root.
dataset = CIFAR100(root='data/', download=True, transform=transform)
test_dataset = CIFAR100(root='data/', train=False, transform=transform)

Downloading:   0%|          | 0.00/276 [00:00<?, ?B/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to data/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting data/cifar-100-python.tar.gz to data/


Create DataLoader

In [5]:
def get_default_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)

def to_device(data, device):
    """Move a tensor, or a (possibly nested) list/tuple of tensors, to `device`.

    Lists and tuples are handled element by element and always come back as a
    list; any other object is moved with `.to(device, non_blocking=True)`.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

class DeviceDataLoader:
    """Wrap an iterable of batches and move each batch to a device when iterated."""

    def __init__(self, dl, device):
        # `dl` is the wrapped loader; `device` is where its batches are sent.
        self.dl, self.device = dl, device

    def __len__(self):
        """Return the number of batches of the wrapped loader."""
        return len(self.dl)

    def __iter__(self):
        """Yield each batch of the wrapped loader after moving it to the device."""
        for batch in self.dl:
            moved_batch = to_device(batch, self.device)
            yield moved_batch

# Device shared by the model and the wrapped data loaders in the cells below.
device = get_default_device()

In [6]:
# Build each DataLoader and wrap it right away so that iterating the loader
# yields batches that have already been moved to `device`.
train_loader = DeviceDataLoader(
    DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2, pin_memory=True),
    device,
)
test_loader = DeviceDataLoader(
    DataLoader(test_dataset, batch_size=32, shuffle=False),
    device,
)
len(train_loader), len(test_loader)

(1563, 313)

Create Model

In [11]:
def format_time_interval(t_diff):
    """Format a duration given in seconds as ``mm:ss`` or ``hh:mm:ss``.

    Args:
        t_diff: Duration in seconds (int or float). The fractional part is
            truncated and negative values are treated as zero.

    Returns:
        ``"mm:ss"`` when the duration is shorter than one hour, otherwise
        ``"hh:mm:ss"``. Hours are not wrapped, so 60 hours is ``"60:00:00"``.
    """
    # Clamp negatives: there is no sensible clock representation for them.
    total_seconds = max(int(t_diff), 0)
    minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    if hours:
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
    # Always emit both fields, so zero seconds renders as "00:00".
    return f"{minutes:02d}:{seconds:02d}"

class Model(nn.Module):
    """Classification model: a pretrained BEiT backbone plus one linear layer.

    The linear layer maps the backbone's `pooler_output` (768 features) to
    `num_classes` logits.

    Args:
        num_classes: Number of output classes.
    """

    # File written by `save_model` and read by `load_model`.
    CHECKPOINT_PATH = "cifar100-beit.pth"

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.beit_model = BeitModel.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
        self.linear1 = nn.Linear(768, num_classes)

    def _set_base_model_trainable(self, trainable):
        """Set `requires_grad` on every parameter whose name starts with 'beit'."""
        for name, parameter in self.named_parameters():
            if name.startswith('beit'):
                parameter.requires_grad = trainable

    def freeze_base_model(self):
        """
        Freeze the layers of the BEiT backbone.
        This is used for transfer learning: only the linear head keeps
        receiving gradient updates.
        """
        self._set_base_model_trainable(False)

    def unfreeze_base_model(self):
        """Make the BEiT backbone trainable again (used for fine-tuning)."""
        self._set_base_model_trainable(True)

    def trainable_parameters(self):
        """Print a parameter summary and return the number of trainable parameters.

        Returns:
            int: Total element count of all parameters with `requires_grad=True`.
        """
        trainable, non_trainable = 0, 0
        for param in self.parameters():
            if param.requires_grad:
                trainable += param.numel()
            else:
                non_trainable += param.numel()
        print("Total Parameters: {:.2f}M".format((non_trainable + trainable) / 1e6))
        print("Non-trainable Prameters: {:.2f}M".format((non_trainable) / 1e6))
        print("Trainable Parameters: {:.2f}M".format(trainable / 1e6))
        return trainable

    def forward(self, xb):
        """Return class logits for a batch of images `xb`."""
        beit_output = self.beit_model(xb)
        out = self.linear1(beit_output.pooler_output)
        return out

    def predict(self, data):
        """Predict class indices without tracking gradients.

        Args:
            data: Either a single input batch (a tensor) or an iterable of
                `(inputs, labels)` batches such as a `DeviceDataLoader`.
                Inputs must already be on the model's device.

        Returns:
            A CPU tensor with one predicted class index per sample. For an
            iterable, the per-batch predictions are concatenated in order.
        """
        self.eval()
        with torch.no_grad():
            if isinstance(data, torch.Tensor):
                out = self(data)
                _, labels = torch.max(out, dim=-1)
            else:
                batches = []
                for xb, _ in tqdm(data):
                    out = self(xb)
                    _, batch_labels = torch.max(out, dim=-1)
                    batches.append(batch_labels)
                # Concatenate rather than stack: the final batch is usually
                # smaller than the others, which `torch.stack` rejects.
                if batches:
                    labels = torch.cat(batches)
                else:
                    labels = torch.empty(0, dtype=torch.long)
        return labels.cpu()

    def acc(self, out, labels):
        """Return the fraction of rows in `out` whose arg-max equals `labels`."""
        _, preds = torch.max(out, dim=1)
        acc = torch.mean((preds == labels).float())
        return acc

    def fit(self, train_loader, val_loader, optimizer, lr, epochs):
        """Train the model and validate after every epoch.

        Args:
            train_loader: Iterable of `(inputs, labels)` training batches;
                must support `len()`.
            val_loader: Iterable of `(inputs, labels)` validation batches.
            optimizer: Optimizer class (not instance); it is called as
                `optimizer(self.parameters(), lr)`.
            lr: Learning rate passed to the optimizer.
            epochs: Number of passes over `train_loader`.

        Returns:
            dict with the keys 'loss', 'acc', 'val_loss' and 'val_acc', each a
            list holding one value per epoch. A checkpoint is written whenever
            the epoch's validation accuracy is the best seen in this call.
        """
        opt = optimizer(self.parameters(), lr)
        history = {
            'loss': [],
            'acc': [],
            'val_loss': [],
            'val_acc': []
        }

        for epoch in range(epochs):
            loss_epoch = 0.0
            acc_epoch = 0.0
            self.train()
            n = len(train_loader)
            t_start = time.time()
            for batch_i, (xb, labels) in enumerate(train_loader):
                # Clear gradients first so anything left over from earlier
                # work (e.g. an interrupted run) cannot leak into this update.
                opt.zero_grad()
                loss, acc = self.step(xb, labels)
                loss.backward()
                opt.step()
                # Running means over the batches seen so far in this epoch.
                loss_epoch = (loss_epoch * batch_i + loss.item()) / (batch_i + 1)
                acc_epoch = (acc_epoch * batch_i + acc.item()) / (batch_i + 1)
                # Elapsed time and projected total time for the epoch.
                t_diff = time.time() - t_start
                ett = (t_diff / (batch_i + 1)) * n
                s_t_diff = format_time_interval(t_diff)
                s_ett = format_time_interval(ett)
                sys.stdout.write(f"\rEpoch: [{epoch+1}/{epochs}]:[{batch_i+1}/{n}], ETA: [{s_t_diff}/{s_ett}] "
                                 f"loss: {loss_epoch:.3f}, acc: {acc_epoch:.3f}")
                sys.stdout.flush()

            val_loss, val_acc = self.evaluate(val_loader, is_training=True)
            history['loss'].append(loss_epoch)
            history['acc'].append(acc_epoch)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            self.epoch_end(epoch, history)
            # `val_acc` was just appended, so this is true exactly when the
            # current epoch is the best (or ties the best) of this call.
            if val_acc >= np.max(history['val_acc']):
                self.save_model(history)

        return history

    def save_model(self, history):
        """Write the weights, best validation accuracy and class count to disk.

        Args:
            history: History dict as built by `fit`; only 'val_acc' is read.
        """
        print("Saving checkpoint")
        torch.save({
            'model_state_dict': self.state_dict(),
            # Stored as a plain float so the checkpoint contains only tensors
            # and builtin Python types.
            'test_acc': float(np.max(history['val_acc'])),
            'num_classes': self.num_classes
        }, self.CHECKPOINT_PATH)

    @classmethod
    def load_model(cls):
        """Rebuild a model from the checkpoint written by `save_model`.

        Returns:
            A new instance with the saved weights loaded. The checkpoint is
            mapped to the CPU; move the result with `.to(device)` as needed.
        """
        checkpoint = torch.load(cls.CHECKPOINT_PATH, map_location='cpu')
        model = cls(checkpoint['num_classes'])
        model.load_state_dict(checkpoint['model_state_dict'])
        return model

    def step(self, xb, labels):
        """Run a forward pass and return `(cross-entropy loss, accuracy)` for one batch."""
        out = self(xb)
        loss = F.cross_entropy(out, labels)
        acc = self.acc(out, labels)
        return loss, acc

    def evaluate(self, val_loader, is_training=True):
        """Compute mean loss and accuracy over `val_loader` without gradients.

        Args:
            val_loader: Iterable of `(inputs, labels)` batches; must support
                `len()`.
            is_training: When False, a per-batch progress line is written to
                stdout. `fit` passes True to keep its own output uncluttered.

        Returns:
            `(val_loss, val_acc)`: the per-batch values averaged over the
            number of batches, so every batch carries equal weight regardless
            of its size.
        """
        val_loss = 0.0
        val_acc = 0.0
        self.eval()
        n = len(val_loader)
        t_start = time.time()
        with torch.no_grad():
            for batch_i, (xb, labels) in enumerate(val_loader):
                loss, acc = self.step(xb, labels)
                val_loss += loss.item()
                val_acc += acc.item()
                if not is_training:
                    t_diff = time.time() - t_start
                    ett = (t_diff / (batch_i + 1)) * n
                    s_t_diff = format_time_interval(t_diff)
                    s_ett = format_time_interval(ett)
                    sys.stdout.write(f"\rBatch: [{batch_i+1}/{n}], ETA: [{s_t_diff}/{s_ett}]")
                    sys.stdout.flush()

        val_loss /= n
        val_acc /= n
        return val_loss, val_acc

    def epoch_end(self, epoch, history):
        """Print the latest validation metrics; continues the progress line written by `fit`."""
        print(",  val_loss: {:.4f}, val_acc: {:.4f}".format(history['val_loss'][-1], history['val_acc'][-1]))

In [12]:
# Build the classifier with 100 output classes, then move it to `device`
# (CUDA when available).
model = Model(num_classes=100)
model = model.to(device)

Some weights of the model checkpoint at microsoft/beit-base-patch16-224-pt22k-ft22k were not used when initializing BeitModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BeitModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BeitModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Transfer Learning

In [13]:
# Stage 1 - transfer learning: freeze the BEiT backbone so that only the
# linear head is updated, and print the resulting parameter counts.
model.freeze_base_model()
_ = model.trainable_parameters()

optimizer = torch.optim.Adam
lr = 3e-4
epochs = 2
history = model.fit(
    train_loader,
    test_loader,  # the test split is used as the validation set
    optimizer=optimizer,
    lr=lr,
    epochs=epochs,
)

Total Parameters: 85.84M
Non-trainable Prameters: 85.76M
Trainable Parameters: 0.08M
Epoch: [1/2]:[1563/1563], ETA: [05:30/05:30] loss: 1.085, acc: 0.781,  val_loss: 0.4916, val_acc: 0.8675
Saving checkpoint
Epoch: [2/2]:[1563/1563], ETA: [05:30/05:30] loss: 0.453, acc: 0.874,  val_loss: 0.4231, val_acc: 0.8800
Saving checkpoint


Fine Tuning

In [14]:
# Stage 2 - fine-tuning: unfreeze the backbone and continue training the whole
# network with a smaller learning rate than in the previous stage.
model.unfreeze_base_model()
_ = model.trainable_parameters()

optimizer = torch.optim.Adam
lr = 1e-5
epochs = 1
history = model.fit(
    train_loader,
    test_loader,  # the test split is used as the validation set
    optimizer=optimizer,
    lr=lr,
    epochs=epochs,
)

Total Parameters: 85.84M
Non-trainable Prameters: 0.00M
Trainable Parameters: 85.84M
Epoch: [1/1]:[1563/1563], ETA: [15:46/15:46] loss: 0.238, acc: 0.927,  val_loss: 0.2418, val_acc: 0.9264
Saving checkpoint


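After fine-tuning, the model reaches a validation accuracy of about 0.926 (measured on the CIFAR-100 test split, which is used as the validation set above). This matches the results reported in the paper.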