In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os

In [None]:
import torch
import torchvision
from torchvision import datasets
from torchvision import transforms as T # for simplifying the transforms
from torchvision.models import VisionTransformer
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, sampler, random_split
from torchvision import models

In [None]:
## Now, we import timm, torchvision image models
!pip install timm 
import timm
from timm.loss import LabelSmoothingCrossEntropy # This is better than normal nn.CrossEntropyLoss

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.9.2-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m82.4 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub (from timm)
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m31.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors (from timm)
  Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m89.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: safetensors, huggingface-hub, timm
Successfully installed huggingface-hub-0.14.1 safetensors-0.3.1 timm-0.9.2


In [None]:
import matplotlib.pyplot as plt
import sys
from tqdm import tqdm
import time
import copy

In [None]:
def get_classes(data_dir):
    """Return the class names (one per sub-folder) that ImageFolder discovers under `data_dir`."""
    return datasets.ImageFolder(data_dir).classes

In [None]:
def get_data_loaders(data_dir, batch_size, train=False, num_workers=4):
    """Build ImageFolder-backed DataLoaders for the dataset rooted at `data_dir`.

    `data_dir` must contain `train/`, `valid/` and `test/` sub-folders in
    torchvision ImageFolder layout (one sub-folder per class).

    Args:
        data_dir: dataset root directory.
        batch_size: batch size for the returned loader(s).
        train: if True, build the augmented, shuffled training loader; otherwise
            build the (non-augmented, unshuffled) validation and test loaders.
        num_workers: worker processes per DataLoader (default 4, as before).

    Returns:
        train=True:  (train_loader, number_of_train_images)
        train=False: (val_loader, test_loader, number_of_val_images, number_of_test_images)
    """
    # ImageNet channel means and standard deviations.
    # NOTE(review): timm ViT checkpoints may have been trained with a different
    # normalization; compare against the model's pretrained config before changing.
    normalize = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    if train:
        transform = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            # NOTE(review): ColorJitter() with its default arguments (brightness,
            # contrast, saturation and hue all 0) leaves images unchanged, so this
            # RandomApply is effectively a no-op. Pass non-zero values to enable it.
            T.RandomApply(torch.nn.ModuleList([T.ColorJitter()]), p=0.25),
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            normalize,
            T.RandomErasing(p=0.2, value='random'),  # operates on the normalized tensor
        ])
        train_data = datasets.ImageFolder(os.path.join(data_dir, "train/"), transform=transform)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                                  num_workers=num_workers)
        return train_loader, len(train_data)

    # Evaluation: deterministic resize/crop + normalization, no augmentation.
    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        normalize,
    ])
    val_data = datasets.ImageFolder(os.path.join(data_dir, "valid/"), transform=transform)
    test_data = datasets.ImageFolder(os.path.join(data_dir, "test/"), transform=transform)
    # Shuffling does not change evaluation metrics. Keeping the order fixed makes
    # batches reproducible across runs, e.g. when inspecting per-sample predictions.
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)
    return val_loader, test_loader, len(val_data), len(test_data)

In [None]:
# Root of the dataset on Google Drive. get_data_loaders expects train/, valid/ and
# test/ sub-folders under it. The folder really is spelled "butterly" (loading
# succeeds below), so keep it as-is unless the folder itself is renamed.
dataset_path = "/content/drive/MyDrive/Vision_Transformer/butterly_dataset"

In [None]:
# Training uses batches of 128; validation and test use batches of 32.
train_loader, train_data_len = get_data_loaders(dataset_path, batch_size=128, train=True)
val_loader, test_loader, valid_data_len, test_data_len = get_data_loaders(
    dataset_path, batch_size=32, train=False
)



In [None]:
# Derive the class list from the same dataset root used for the loaders rather
# than repeating the hard-coded Drive path, so the two cannot drift apart.
classes = get_classes(os.path.join(dataset_path, "train"))
print(classes, len(classes))

['ADONIS', 'AFRICAN GIANT SWALLOWTAIL', 'AMERICAN SNOOT', 'AN 88', 'APPOLLO', 'ARCIGERA FLOWER MOTH', 'ATALA', 'ATLAS MOTH', 'BANDED ORANGE HELICONIAN', 'BANDED PEACOCK', 'BANDED TIGER MOTH', 'BECKERS WHITE', 'BIRD CHERRY ERMINE MOTH', 'BLACK HAIRSTREAK', 'BLUE MORPHO', 'BLUE SPOTTED CROW', 'BROOKES BIRDWING', 'BROWN ARGUS', 'BROWN SIPROETA', 'CABBAGE WHITE', 'CAIRNS BIRDWING', 'CHALK HILL BLUE', 'CHECQUERED SKIPPER', 'CHESTNUT', 'CINNABAR MOTH', 'CLEARWING MOTH', 'CLEOPATRA', 'CLODIUS PARNASSIAN', 'CLOUDED SULPHUR', 'COMET MOTH', 'COMMON BANDED AWL', 'COMMON WOOD-NYMPH', 'COPPER TAIL', 'CRECENT', 'CRIMSON PATCH', 'DANAID EGGFLY', 'EASTERN COMA', 'EASTERN DAPPLE WHITE', 'EASTERN PINE ELFIN', 'ELBOWED PIERROT', 'EMPEROR GUM MOTH', 'GARDEN TIGER MOTH', 'GIANT LEOPARD MOTH', 'GLITTERING SAPPHIRE', 'GOLD BANDED', 'GREAT EGGFLY', 'GREAT JAY', 'GREEN CELLED CATTLEHEART', 'GREEN HAIRSTREAK', 'GREY HAIRSTREAK', 'HERCULES MOTH', 'HUMMING BIRD HAWK MOTH', 'INDRA SWALLOW', 'IO MOTH', 'Iphiclus si

In [None]:
# Per-phase loaders and split sizes, keyed by the phase names train_model iterates over.
dataloaders = dict(train=train_loader, val=val_loader)
dataset_sizes = dict(train=train_data_len, val=valid_data_len)

In [None]:
# Number of batches in each loader (train, val, test).
print(len(train_loader), len(val_loader), len(test_loader))

99 16 16


In [None]:
# Number of images in each split (train, val, test).
print(train_data_len, valid_data_len, test_data_len)

12594 500 500


In [None]:
# Use the GPU when one is available; the model and each batch are moved here later.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [None]:
# A torchvision-hub copy of ViT-B/16 used to be downloaded here
# (torch.hub.load("pytorch/vision", 'vit_b_16', pretrained=True)). It was never used:
# the next cell immediately rebinds `model` to a timm ViT. The ~330 MB download is
# therefore skipped, and the timm model built in the next cell is the one trained.

Downloading: "https://github.com/pytorch/vision/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:04<00:00, 76.7MB/s]


In [None]:
# Pretrained ViT backbone from timm ('vit_base_patch16_224').
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# Freeze every pretrained weight. The new head is created afterwards,
# so it is the only trainable part of the network.
model.requires_grad_(False)

# Width of the features fed into the classifier head.
n_inputs = model.head.in_features

# Swap the classifier for a small MLP head with one output per class.
model.head = nn.Sequential(
    nn.Linear(n_inputs, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, len(classes)),
)


Downloading model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [None]:
# Move the model to the device chosen in the device-selection cell.
# `device` is not redefined here, so there is a single source of truth for it.
model = model.to(device)

In [None]:
# Show the replacement classification head to confirm its layer sizes.
print(model.head)

Sequential(
  (0): Linear(in_features=768, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=512, out_features=100, bias=True)
)


In [None]:
# Label-smoothed cross-entropy loss from timm.
criterion = LabelSmoothingCrossEntropy()
# NOTE(review): the loss presumably holds no parameters/buffers, so this .to() is likely a no-op — confirm.
criterion = criterion.to(device)
# Only the new head's parameters are optimized; the backbone was frozen when the model was built.
optimizer = optim.Adam(model.head.parameters(), lr=0.001)

In [None]:
# LR scheduler: multiply the learning rate by 0.97 every 3 scheduler steps.
# train_model calls scheduler.step() once per epoch, so the decay happens every 3 epochs.
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.97)

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=30,
                loaders=None, sizes=None, run_device=None):
    """Train `model` with a train/val loop per epoch and restore the best-val-accuracy weights.

    Each epoch runs a 'train' phase (gradients on, optimizer stepped per batch,
    scheduler stepped once at the end of the phase) followed by a 'val' phase
    (gradients off). The weights with the highest validation accuracy seen so far
    are snapshotted and loaded back into `model` before returning.

    Args:
        model: nn.Module to train; it is updated in place and also returned.
        criterion: loss function called as criterion(outputs, labels).
        optimizer: optimizer over the trainable parameters.
        scheduler: LR scheduler, stepped once per epoch after the training phase.
        num_epochs: number of epochs to run.
        loaders: dict with 'train' and 'val' DataLoaders. Defaults to the
            notebook-level `dataloaders` dict.
        sizes: dict mapping 'train'/'val' to the number of samples per split, used
            to average loss and accuracy. Defaults to the notebook-level
            `dataset_sizes` when `loaders` is not given, otherwise to
            len(loader.dataset) for each phase.
        run_device: device batches are moved to. Defaults to the notebook-level `device`.

    Returns:
        (model, train_losses, val_losses): the model with the best validation
        weights loaded, and the per-epoch average loss of each phase.
    """
    # Fall back to the notebook globals so existing calls keep working unchanged.
    if loaders is None:
        loaders = dataloaders
        if sizes is None:
            sizes = dataset_sizes
    elif sizes is None:
        sizes = {phase: len(loader.dataset) for phase, loader in loaders.items()}
    if run_device is None:
        run_device = device

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # Per-epoch average losses for plotting.
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print("-" * 10)

        for phase in ['train', 'val']:  # one training and one validation pass per epoch
            is_train = phase == 'train'
            model.train(is_train)  # train mode for 'train', eval mode for 'val'

            running_loss = 0.0
            # Plain int counter: stays valid even if a loader yields no batches.
            running_corrects = 0

            for inputs, labels in tqdm(loaders[phase]):
                inputs = inputs.to(run_device)
                labels = labels.to(run_device)

                if is_train:
                    optimizer.zero_grad()

                # Disabling autograd during validation saves memory and time.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)  # predicted class per sample
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Weight the batch-mean loss by batch size so the epoch average is per-sample.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            if is_train:
                scheduler.step()  # LR schedule advances once per epoch

            epoch_loss = running_loss / sizes[phase]
            epoch_acc = running_corrects / sizes[phase]

            (train_losses if is_train else val_losses).append(epoch_loss)

            print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))

            # Keep a copy of the weights with the best validation accuracy so far.
            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print("Best Val Acc: {:.4f}".format(best_acc))

    model.load_state_dict(best_model_wts)
    return model, train_losses, val_losses


In [None]:
# Train for the default 30 epochs. Returns the model with the best-validation-accuracy
# weights loaded, plus the per-epoch train and validation losses.
model_ft, train_losses, val_losses = train_model(model, criterion, optimizer, exp_lr_scheduler)

Epoch 0/29
----------


100%|██████████| 99/99 [42:47<00:00, 25.94s/it]


train Loss: 2.1208 Acc: 0.6502


100%|██████████| 16/16 [01:42<00:00,  6.38s/it]


val Loss: 1.2816 Acc: 0.8900

Epoch 1/29
----------


100%|██████████| 99/99 [02:20<00:00,  1.42s/it]


train Loss: 1.2858 Acc: 0.8905


100%|██████████| 16/16 [00:06<00:00,  2.44it/s]


val Loss: 1.1621 Acc: 0.9240

Epoch 2/29
----------


100%|██████████| 99/99 [02:19<00:00,  1.40s/it]


train Loss: 1.1768 Acc: 0.9212


100%|██████████| 16/16 [00:06<00:00,  2.48it/s]


val Loss: 1.1069 Acc: 0.9200

Epoch 3/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 1.1193 Acc: 0.9391


100%|██████████| 16/16 [00:07<00:00,  2.27it/s]


val Loss: 1.0977 Acc: 0.9300

Epoch 4/29
----------


100%|██████████| 99/99 [02:19<00:00,  1.41s/it]


train Loss: 1.0818 Acc: 0.9495


100%|██████████| 16/16 [00:06<00:00,  2.47it/s]


val Loss: 1.0414 Acc: 0.9520

Epoch 5/29
----------


100%|██████████| 99/99 [02:19<00:00,  1.41s/it]


train Loss: 1.0587 Acc: 0.9528


100%|██████████| 16/16 [00:06<00:00,  2.52it/s]


val Loss: 1.0474 Acc: 0.9560

Epoch 6/29
----------


100%|██████████| 99/99 [02:19<00:00,  1.41s/it]


train Loss: 1.0373 Acc: 0.9602


100%|██████████| 16/16 [00:06<00:00,  2.51it/s]


val Loss: 1.0346 Acc: 0.9500

Epoch 7/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 1.0250 Acc: 0.9613


100%|██████████| 16/16 [00:06<00:00,  2.33it/s]


val Loss: 1.0519 Acc: 0.9420

Epoch 8/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 1.0076 Acc: 0.9686


100%|██████████| 16/16 [00:06<00:00,  2.30it/s]


val Loss: 1.0208 Acc: 0.9560

Epoch 9/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 0.9951 Acc: 0.9726


100%|██████████| 16/16 [00:06<00:00,  2.45it/s]


val Loss: 1.0200 Acc: 0.9540

Epoch 10/29
----------


100%|██████████| 99/99 [02:19<00:00,  1.41s/it]


train Loss: 0.9862 Acc: 0.9743


100%|██████████| 16/16 [00:06<00:00,  2.52it/s]


val Loss: 1.0337 Acc: 0.9500

Epoch 11/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 0.9849 Acc: 0.9736


100%|██████████| 16/16 [00:06<00:00,  2.33it/s]


val Loss: 1.0105 Acc: 0.9480

Epoch 12/29
----------


100%|██████████| 99/99 [02:19<00:00,  1.41s/it]


train Loss: 0.9692 Acc: 0.9786


100%|██████████| 16/16 [00:06<00:00,  2.34it/s]


val Loss: 1.0152 Acc: 0.9420

Epoch 13/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 0.9668 Acc: 0.9774


100%|██████████| 16/16 [00:06<00:00,  2.46it/s]


val Loss: 1.0222 Acc: 0.9560

Epoch 14/29
----------


100%|██████████| 99/99 [02:19<00:00,  1.41s/it]


train Loss: 0.9650 Acc: 0.9779


100%|██████████| 16/16 [00:06<00:00,  2.47it/s]


val Loss: 1.0068 Acc: 0.9660

Epoch 15/29
----------


100%|██████████| 99/99 [02:19<00:00,  1.41s/it]


train Loss: 0.9593 Acc: 0.9798


100%|██████████| 16/16 [00:06<00:00,  2.51it/s]


val Loss: 1.0133 Acc: 0.9560

Epoch 16/29
----------


100%|██████████| 99/99 [02:19<00:00,  1.41s/it]


train Loss: 0.9510 Acc: 0.9821


100%|██████████| 16/16 [00:06<00:00,  2.39it/s]


val Loss: 1.0262 Acc: 0.9540

Epoch 17/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 0.9462 Acc: 0.9851


100%|██████████| 16/16 [00:07<00:00,  2.26it/s]


val Loss: 1.0219 Acc: 0.9480

Epoch 18/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 0.9493 Acc: 0.9821


100%|██████████| 16/16 [00:06<00:00,  2.46it/s]


val Loss: 1.0149 Acc: 0.9520

Epoch 19/29
----------


100%|██████████| 99/99 [02:19<00:00,  1.41s/it]


train Loss: 0.9422 Acc: 0.9851


100%|██████████| 16/16 [00:06<00:00,  2.49it/s]


val Loss: 1.0132 Acc: 0.9560

Epoch 20/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 0.9380 Acc: 0.9863


100%|██████████| 16/16 [00:06<00:00,  2.48it/s]


val Loss: 0.9974 Acc: 0.9700

Epoch 21/29
----------


100%|██████████| 99/99 [02:19<00:00,  1.41s/it]


train Loss: 0.9329 Acc: 0.9872


100%|██████████| 16/16 [00:06<00:00,  2.51it/s]


val Loss: 1.0028 Acc: 0.9620

Epoch 22/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 0.9320 Acc: 0.9876


100%|██████████| 16/16 [00:06<00:00,  2.33it/s]


val Loss: 1.0073 Acc: 0.9660

Epoch 23/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 0.9299 Acc: 0.9877


100%|██████████| 16/16 [00:07<00:00,  2.26it/s]


val Loss: 1.0139 Acc: 0.9540

Epoch 24/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 0.9273 Acc: 0.9879


100%|██████████| 16/16 [00:06<00:00,  2.47it/s]


val Loss: 1.0099 Acc: 0.9540

Epoch 25/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 0.9220 Acc: 0.9891


100%|██████████| 16/16 [00:06<00:00,  2.45it/s]


val Loss: 1.0094 Acc: 0.9580

Epoch 26/29
----------


100%|██████████| 99/99 [02:20<00:00,  1.42s/it]


train Loss: 0.9194 Acc: 0.9893


100%|██████████| 16/16 [00:06<00:00,  2.49it/s]


val Loss: 1.0026 Acc: 0.9540

Epoch 27/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 0.9223 Acc: 0.9888


100%|██████████| 16/16 [00:07<00:00,  2.28it/s]


val Loss: 1.0005 Acc: 0.9580

Epoch 28/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 0.9189 Acc: 0.9898


100%|██████████| 16/16 [00:06<00:00,  2.46it/s]


val Loss: 1.0067 Acc: 0.9660

Epoch 29/29
----------


100%|██████████| 99/99 [02:18<00:00,  1.40s/it]


train Loss: 0.9184 Acc: 0.9920


100%|██████████| 16/16 [00:06<00:00,  2.50it/s]

val Loss: 0.9942 Acc: 0.9580

Training complete in 114m 57s
Best Val Acc: 0.9700





In [None]:

# Per-epoch training vs. validation loss curves returned by train_model.
plt.figure(figsize=(10, 5))
for curve, curve_label in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
    plt.plot(curve, label=curve_label)
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Train and Validation Losses Over Epochs", fontsize=14)
plt.legend()
plt.grid(True)
plt.show()


In [None]:
# Evaluate the best model on the held-out test split: average loss plus per-class accuracy.
test_loss = 0.0
class_correct = [0] * len(classes)
class_total = [0] * len(classes)
model_ft.eval()

for data, target in tqdm(test_loader):
    data, target = data.to(device), target.to(device)
    with torch.no_grad():  # no autograd needed for evaluation
        output = model_ft(data)
        loss = criterion(output, target)
    # Accumulate the batch-size-weighted loss over every batch, so the division
    # below yields the per-sample average over the whole test set.
    test_loss += loss.item() * data.size(0)
    _, pred = torch.max(output, 1)
    # Count every sample, including a final partial batch, whatever the batch size.
    # Plain Python lists also avoid the 0-d-array pitfall of squeezing a size-1 batch.
    for label, is_correct in zip(target.tolist(), pred.eq(target).tolist()):
        class_correct[label] += int(is_correct)
        class_total[label] += 1

test_loss = test_loss / test_data_len
print('Test Loss: {:.4f}'.format(test_loss))
for i in range(len(classes)):
    if class_total[i] > 0:
        print("Test Accuracy of %5s: %2d%% (%2d/%2d)" % (
            classes[i], 100*class_correct[i]/class_total[i], class_correct[i], class_total[i]
        ))
    else:
        print("Test accuracy of %5s: NA" % (classes[i]))
if sum(class_total) > 0:
    print("Test Accuracy of %2d%% (%2d/%2d)" % (
                100*sum(class_correct)/sum(class_total), sum(class_correct), sum(class_total)
            ))

100%|██████████| 16/16 [01:35<00:00,  5.95s/it]

Test Loss: 0.0355
Test Accuracy of ADONIS: 100% ( 5/ 5)
Test Accuracy of AFRICAN GIANT SWALLOWTAIL: 100% ( 5/ 5)
Test Accuracy of AMERICAN SNOOT: 100% ( 5/ 5)
Test Accuracy of AN 88: 100% ( 5/ 5)
Test Accuracy of APPOLLO: 100% ( 5/ 5)
Test Accuracy of ARCIGERA FLOWER MOTH: 100% ( 5/ 5)
Test Accuracy of ATALA: 100% ( 5/ 5)
Test Accuracy of ATLAS MOTH: 100% ( 5/ 5)
Test Accuracy of BANDED ORANGE HELICONIAN: 100% ( 5/ 5)
Test Accuracy of BANDED PEACOCK: 100% ( 5/ 5)
Test Accuracy of BANDED TIGER MOTH: 100% ( 4/ 4)
Test Accuracy of BECKERS WHITE: 80% ( 4/ 5)
Test Accuracy of BIRD CHERRY ERMINE MOTH: 100% ( 5/ 5)
Test Accuracy of BLACK HAIRSTREAK: 100% ( 5/ 5)
Test Accuracy of BLUE MORPHO: 80% ( 4/ 5)
Test Accuracy of BLUE SPOTTED CROW: 100% ( 5/ 5)
Test Accuracy of BROOKES BIRDWING: 100% ( 4/ 4)
Test Accuracy of BROWN ARGUS: 75% ( 3/ 4)
Test Accuracy of BROWN SIPROETA: 100% ( 4/ 4)
Test Accuracy of CABBAGE WHITE: 100% ( 5/ 5)
Test Accuracy of CAIRNS BIRDWING: 100% ( 5/ 5)
Test Accuracy of 


