In [1]:
import os
import math
import time
import evaluate
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder

from transformers import ViTForImageClassification, TrainingArguments, Trainer, ViTImageProcessor

In [2]:
data_path = 'flower_photos'
dataset = ImageFolder(root=data_path)
num_samples = len(dataset)
classes = dataset.classes
num_classes = len(dataset.classes)

TRAIN_RATIO, VALID_RATIO = 0.8, 0.1
n_train_examples = int(num_samples * TRAIN_RATIO)
n_valid_examples = int(num_samples * VALID_RATIO)
n_test_examples = num_samples - n_train_examples - n_valid_examples
train_dataset, valid_dataset, test_dataset = random_split(
    dataset,
    [n_train_examples, n_valid_examples, n_test_examples]
)

In [3]:
IMG_SIZE = 224

train_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

test_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

train_dataset.dataset.transform = train_transforms
valid_dataset.dataset.transform = test_transforms
test_dataset.dataset.transform = test_transforms

In [4]:
BATCH_SIZE = 512

train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=BATCH_SIZE
)

val_loader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE
)

In [5]:
class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=ff_dim, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=ff_dim, out_features=embed_dim, bias=True)
        )
        self.ln1 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.ln2 = nn.LayerNorm(normalized_shape=embed_dim, eps=1e-6)
        self.dropout1 = nn.Dropout(p=dropout)
        self.dropout2 = nn.Dropout(p=dropout)

    def forward(self, query, key, value):
        attn_output, _ = self.attn(query, key, value)
        attn_output = self.dropout1(attn_output)
        out1 = self.ln1(query + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        out2 = self.ln2(out1 + ffn_output)
        return out2

In [6]:
class PatchPositionEmbedding(nn.Module):
    def __init__(self, img_size=224, embed_dim=512, patch_size=16, device='cpu'):
        super(PatchPositionEmbedding, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)
        scale = embed_dim ** -0.5
        self.positional_embedding = nn.Parameter(scale * torch.rand((img_size // patch_size) ** 2, embed_dim))
        self.device = device

    def forward(self, x):
        x = self.conv1(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)

        x = x + self.positional_embedding.to(self.device)
        return x

In [7]:
class VisionTransformerCls(nn.Module):
    def __init__(
        self,
        image_size,
        embed_dim,
        num_heads,
        ff_dim,
        dropout=0.1,
        device='cpu',
        num_classes=10,
        patch_size=16
    ):
        super(VisionTransformerCls, self).__init__()
        self.embed_layer = PatchPositionEmbedding(
            img_size=image_size,
            embed_dim=embed_dim,
            patch_size=patch_size,
            device=device
        )
        self.transformer_layer = TransformerEncoder(
            embed_dim,
            num_heads,
            ff_dim,
            dropout
        )
        self.fc1 = nn.Linear(in_features=embed_dim, out_features=20)
        self.fc2 = nn.Linear(in_features=20, out_features=num_classes)
        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        output = self.embed_layer(x)
        output = self.transformer_layer(output, output, output)
        output = output[:, 0, :]
        output = self.dropout(output)
        output = self.fc1(output)
        output = self.dropout(output)
        output = self.fc2(output)
        return output

In [8]:
def train_epoch(model, optimizer, criterion, train_dataloader, device, epoch=0, log_interval=50):
    model.train()
    total_acc, total_count = 0, 0
    losses = []
    start_time = time.time()

    for idx, (inputs, labels) in enumerate(train_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        predictions = model(inputs)

        loss = criterion(predictions, labels)
        losses.append(loss.item())

        loss.backward()
        optimizer.step()
        total_acc += (predictions.argmax(1) == labels).sum().item()
        total_count += labels.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print(
                '| Epoch {:3d} | {:5d}/{:5d} batches '
                '| Accuracy {:8.3f}'.format(
                    epoch, idx, len(train_dataloader), total_acc / total_count
                )
            )
            total_acc, total_count = 0, 0
            start_time = time.time()

    epoch_acc = total_acc / total_count
    epoch_loss = sum(losses) / len(losses)
    return epoch_acc, epoch_loss

def evaluate_epoch(model, criterion, valid_dataloader, device):
    model.eval()
    total_acc, total_count = 0, 0
    losses =[]

    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(valid_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            predictions = model(inputs)

            loss = criterion(predictions, labels)
            losses.append(loss.item())

            total_acc += (predictions.argmax(1) == labels).sum().item()
            total_count += labels.size(0)

    epoch_acc = total_acc / total_count
    epoch_loss = sum(losses) / len(losses)
    return epoch_acc, epoch_loss

def train(model, model_name, save_model, optimizer, criterion, train_dataloader, valid_dataloader, num_epochs, device):
    train_accs, train_losses = [], []
    eval_accs, eval_losses = [], []
    best_loss_eval = 100
    times = []
    for epoch in range(1, num_epochs + 1):
        epoch_start_time = time.time()

        train_acc, train_loss = train_epoch(model, optimizer, criterion, train_dataloader, device, epoch)
        train_accs.append(train_acc)
        train_losses.append(train_loss)

        eval_acc, eval_loss = evaluate_epoch(model, criterion, valid_dataloader, device)
        eval_accs.append(eval_acc)
        eval_losses.append(eval_loss)

        if eval_loss < best_loss_eval:
            torch.save(model.state_dict(), save_model + f'/{model_name}.pt')

        times.append(time.time() - epoch_start_time)

        print('-' * 59)
        print(
            '| End of epoch {:3d} | Time: {:5.2f}s | Train Accuracy {:8.3f} | Train Loss {:8.3f} '
            '| Valid Accuracy {:8.3f} | Valid Loss {:8.3f} '.format(
                epoch, time.time() - epoch_start_time, train_acc, train_loss, eval_acc, eval_loss
            )
        )
        print('-' * 59)

    model.load_state_dict(torch.load(save_model + f'/{model_name}.pt'))
    model.eval()
    metrics = {
        'train_accuracy': train_accs,
        'train_loss': train_losses,
        'valid_accuracy': eval_accs,
        'valid_loss': eval_losses,
        'time': times
    }
    return model, metrics

In [9]:
def plot_result(num_epochs, train_accs, eval_accs, train_losses, eval_losses):
    epochs = list(range(num_epochs))
    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12,6))
    axs[0].plot(epochs, train_accs, label='Training')
    axs[0].plot(epochs, eval_accs, label='Evaluation')
    axs[1].plot(epochs, train_losses, label='Training')
    axs[1].plot(epochs, eval_losses, label='Evaluation')
    axs[0].set_xlabel('Epochs')
    axs[1].set_xlabel('Epochs')
    axs[0].set_ylabel('Accuracy')
    axs[1].set_ylabel('Loss')
    plt.legend()

In [10]:
image_size = 224
embed_dim = 512
num_heads = 4
ff_dim = 128
dropout = 0.1

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = VisionTransformerCls(
    image_size=image_size,
    embed_dim=512,
    num_heads=num_heads,
    ff_dim=ff_dim,
    dropout=dropout,
    num_classes=num_classes,
    device=device
)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)

num_epochs = 100
save_model = 'vit_flowers'
os.makedirs(save_model, exist_ok=True)
model_name = 'vit_flowers'

model, metrics = train(
    model,
    model_name,
    save_model,
    optimizer,
    criterion,
    train_loader,
    val_loader,
    num_epochs,
    device
)

-----------------------------------------------------------
| End of epoch   1 | Time: 20.45s | Train Accuracy    0.287 | Train Loss    1.893 | Valid Accuracy    0.335 | Valid Loss    1.594 
-----------------------------------------------------------
-----------------------------------------------------------
| End of epoch   2 | Time: 18.46s | Train Accuracy    0.348 | Train Loss    1.508 | Valid Accuracy    0.362 | Valid Loss    1.419 
-----------------------------------------------------------
-----------------------------------------------------------
| End of epoch   3 | Time: 17.79s | Train Accuracy    0.372 | Train Loss    1.410 | Valid Accuracy    0.390 | Valid Loss    1.359 
-----------------------------------------------------------
-----------------------------------------------------------
| End of epoch   4 | Time: 18.43s | Train Accuracy    0.408 | Train Loss    1.330 | Valid Accuracy    0.422 | Valid Loss    1.297 
--------------------------------------------------------

  model.load_state_dict(torch.load(save_model + f'/{model_name}.pt'))


In [11]:
id2label = {id:label for id, label in enumerate(classes)}
label2id = {label:id for id, label in id2label.items()}

model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes,
    id2label=id2label,
    label2id=label2id
).to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
metric = evaluate.load('accuracy')

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [13]:
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

metric_name = 'accuracy'

args = TrainingArguments(
    f'vit_flowers',
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=10,
    weight_decay=0.01,
    eval_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs',
    remove_unused_columns=False,
    report_to='none'
)

def collate_fn(examples):
    pixel_values = torch.stack([example[0] for example in examples])
    labels = torch.tensor([example[1] for example in examples])
    return {'pixel_values': pixel_values, 'labels': labels}

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

In [14]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor
)

  trainer = Trainer(


In [15]:
trainer.train()
outputs = trainer.predict(test_dataset)
outputs.metrics

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.572757,0.972752
2,No log,0.245856,0.983651
3,No log,0.169789,0.983651
4,No log,0.142573,0.986376
5,No log,0.125625,0.980926
6,0.343800,0.113832,0.983651
7,0.343800,0.108867,0.983651
8,0.343800,0.103768,0.983651
9,0.343800,0.102015,0.983651
10,0.343800,0.100797,0.983651


{'test_loss': 0.15708990395069122,
 'test_accuracy': 0.9645776566757494,
 'test_runtime': 5.8921,
 'test_samples_per_second': 62.287,
 'test_steps_per_second': 2.037}