<a href="https://colab.research.google.com/github/AriPathak/DinoV2_Flowers102/blob/main/DINOV2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
!pip install vit_pytorch
from vit_pytorch import ViT, Dino
import torchvision
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor
from skimage import io
from torchvision import datasets, transforms, models
import tarfile
from torchvision.datasets.utils import download_url
from torch.utils.data import random_split, TensorDataset, Dataset
import torch.nn as nn
import seaborn as sns
from PIL import Image
from collections import OrderedDict
from torch import optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import pandas as pd

Collecting vit_pytorch
  Downloading vit_pytorch-1.6.5-py3-none-any.whl (100 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m100.3/100.3 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops>=0.7.0 (from vit_pytorch)
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: einops, vit_pytorch
Successfully installed einops-0.7.0 vit_pytorch-1.6.5


In [None]:
# Fetch the 'dinov2_vits14' entry point of the facebookresearch/dinov2 repo
# through torch.hub (downloads the repo and checkpoint on first use).
dinov2_vits14 = torch.hub.load(repo_or_dir='facebookresearch/dinov2', model='dinov2_vits14')

# Smoke test: feed one random 1x3x224x224 tensor through the backbone and
# display the shape of the result as the cell output.
test_img = torch.randn((1, 3, 224, 224))
dinov2_vits14(test_img).shape

Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vits14_pretrain.pth
100%|██████████| 84.2M/84.2M [00:00<00:00, 190MB/s]


torch.Size([1, 384])

In [None]:
# Tip: iterate over dinov2_vits14.named_modules() to list every sub-module name.

# Sub-modules that are fine-tuned; every other parameter of the backbone
# stays frozen.
learnable_modules = ['blocks.8',
                    'blocks.9',
                    'blocks.10',
                    'blocks.11']

# The pretrained patch projection is intentionally left in place. Swapping it
# for a fresh nn.Conv2d(4, 384, ...) would (a) reject the 3-channel tensors
# produced by the data pipeline and (b) leave a randomly initialised layer
# frozen by the requires_grad_(False) call below.
dinov2_vits14.requires_grad_(False)

# Re-enable gradients only for the selected sub-modules.
modules = dict(dinov2_vits14.named_modules())
for name in learnable_modules:
    modules[name].requires_grad_(True)

In [None]:
class L2Norm(nn.Module):
    """Divide the input by its Euclidean norm taken along dimension 1."""

    def forward(self, x, eps = 1e-6):
        # The norm is floored at `eps` so an all-zero slice divides by eps
        # instead of by zero.
        magnitude = torch.clamp(torch.norm(x, dim=1, keepdim=True), min=eps)
        return torch.div(x, magnitude)

class MLP(nn.Module):
  """Classification head made of a single linear layer (dim -> num_classes)."""

  def __init__(self, dim, num_classes):
    super().__init__()
    # The attribute name `layer` determines the state_dict keys; keep it.
    self.layer = nn.Linear(in_features=dim, out_features=num_classes)

  def forward(self, x):
    logits = self.layer(x)
    return logits

In [None]:
# Download the flowers archive into the current working directory.
dataset_url = "https://s3.amazonaws.com/fast-ai-imageclas/oxford-102-flowers.tgz"
download_url(url=dataset_url, root='.')

# Unpack the gzip-compressed tarball under ./data.
with tarfile.open(name='./oxford-102-flowers.tgz', mode='r:gz') as tar:
    tar.extractall(path='./data')

# Path of a single sample image inside the extracted archive.
img_path='./data/oxford-102-flowers/jpg/image_00001.jpg'

class flowersmodel(Dataset):
  """Image / label dataset driven by a space-separated annotation file.

  Column 0 of the annotation file is an image path relative to `root_dir`,
  column 1 is the integer class label.

  Args:
    excel_file: path of the annotation text file (one "path label" per line).
    root_dir: directory the image paths are relative to.
    transform: optional callable applied to the PIL image before returning it.
  """
  def __init__(self,excel_file,root_dir,transform=None):
    # header=None: every line of the file is a sample. Without it pandas
    # consumes the first line as column names and that sample is lost.
    self.annotations=pd.read_csv(excel_file,delimiter=' ',header=None)
    self.root_dir=root_dir
    self.transform=transform

  def __len__(self):
    """Number of annotated samples."""
    return len(self.annotations)

  def __getitem__(self,index):
    """Return `(image, label)` for row `index` of the annotation table."""
    img_path=os.path.join(self.root_dir,self.annotations.iloc[index,0])
    y_label=torch.tensor(self.annotations.iloc[index,1])
    # Decode the file once (the former extra skimage read was discarded) and
    # close the handle promptly. convert('RGB') guarantees three channels for
    # the 3-channel Normalize used downstream; resample=0 is nearest-neighbour.
    with Image.open(img_path) as img:
      image=img.convert('RGB').resize((300,300),resample=0)
    if self.transform:
      image=self.transform(image)
    return (image,y_label)

# Per-channel mean / std handed to Normalize in every pipeline (the values
# commonly used with ImageNet-pretrained backbones).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# One preprocessing pipeline per split. Only the training split is augmented
# (random rotation and horizontal flip); validation and testing just resize,
# convert to a tensor and normalise.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomRotation(75),
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]),
    validation=transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]),
    testing=transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]),
)


# One dataset per split; each reads its own annotation file and applies the
# matching transform pipeline.
train_dataset = flowersmodel(
    excel_file='./data/oxford-102-flowers/train.txt',
    root_dir='./data/oxford-102-flowers',
    transform=data_transforms['train'],
)
test_dataset = flowersmodel(
    excel_file='./data/oxford-102-flowers/test.txt',
    root_dir='./data/oxford-102-flowers',
    transform=data_transforms['testing'],
)
val_dataset = flowersmodel(
    excel_file='./data/oxford-102-flowers/valid.txt',
    root_dir='./data/oxford-102-flowers',
    transform=data_transforms['validation'],
)



Using downloaded and verified file: ./oxford-102-flowers.tgz


In [None]:
# (split name, dataset, batch size, shuffle) for every loader we need.
_loader_specs = (
    ('training', train_dataset, 64, True),
    ('testing', test_dataset, 1, False),
    ('validation', val_dataset, 64, True),
)

# Build the split-name -> DataLoader mapping from the table above.
dataloaders = {
    split: torch.utils.data.DataLoader(split_dataset, batch_size=batch, shuffle=do_shuffle)
    for split, split_dataset, batch, do_shuffle in _loader_specs
}

In [None]:
def display(inp):
  """Show an image tensor, or every image of a batch, with matplotlib.

  Args:
    inp: a 3-D tensor (one image) or a 4-D tensor (a batch). When the second
      dimension of the batched tensor has size 3 it is treated as the channel
      axis (B, C, H, W) and moved last; otherwise the tensor is assumed to be
      channels-last (B, H, W, C) already.

  Raises:
    ValueError: if `inp` is neither 3-D nor 4-D.
  """
  if inp.dim() == 3:
    inp = inp.unsqueeze(0)  # treat a single image as a batch of one
  if inp.dim() != 4:
    raise ValueError(f'expected a 3-D or 4-D tensor, got shape {tuple(inp.shape)}')
  # imshow needs (H, W, C); it cannot draw a 4-D or channels-first array.
  if inp.shape[1] == 3:
    inp = inp.permute(0, 2, 3, 1)
  for image in inp.detach().cpu():
    plt.imshow(image.numpy())
    plt.show()

In [None]:
class ViT_Dino(nn.Module):
  """Backbone encoder followed by a linear classification head.

  Args:
    num_classes: number of output logits.
    in_feats: size of the feature vector produced by the encoder.
  """

  def __init__(self, num_classes, in_feats=384):
    super().__init__()
    # The encoder is the module-level `dinov2_vits14` object itself (shared,
    # not copied), so training this model updates that object.
    self.encoder = dinov2_vits14
    self.classifier = MLP(dim=in_feats, num_classes=num_classes)

  def forward(self, x):
    features = self.encoder(x)
    logits = self.classifier(features)
    return logits


In [None]:
from tqdm.notebook import tqdm
import time
import os
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs on a runtime without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
def pretrain(model, learner, train_loader, optimizer, epochs, PATH=None):
  """Pre-train `model` by minimising the loss returned by `learner`.

  Args:
    model: module being trained; moved to `device`, put in train mode and,
      when `PATH` is given, saved at the end.
    learner: callable such that `learner(images)` returns a scalar loss; it
      must also provide `update_moving_average()`, called after every
      optimizer step.
    train_loader: iterable yielding `(images, labels)` batches; the labels
      are not used.
    optimizer: optimizer stepping the trainable parameters.
    epochs: number of passes over `train_loader`.
    PATH: optional file path; when truthy the model state_dict is saved there.
  """
  model.to(device)
  model.train()
  for epoch in tqdm(range(1, epochs + 1)):
    # Running sum of the batch losses, kept as a Python float so no autograd
    # graph is retained between steps.
    avg_loss = 0.0
    num_steps = 0
    for check, (images, _) in enumerate(tqdm(train_loader, leave=False)):
      images = images.to(device)
      optimizer.zero_grad()
      loss = learner(images)
      loss.backward()
      optimizer.step()
      learner.update_moving_average()
      avg_loss += loss.item()
      num_steps += 1
      # Progress report every 100 steps (step index restarts each epoch).
      if check % 100 == 0:
        print(loss.detach())
    if num_steps:
      print(f'epoch {epoch:2}', f'avg loss: {avg_loss / num_steps:.3f}', sep='\t')
  if PATH:
    torch.save(model.state_dict(), PATH)

def compute_accuracy(model, loader):
    """Return the fraction of samples in `loader` that `model` classifies correctly.

    The model is moved to `device` and left in eval mode. The prediction for
    a sample is the index of its largest output along dimension 1.

    Args:
        model: module mapping a batch of inputs to per-class scores.
        loader: iterable of `(inputs, labels)` batches exposing `.dataset`,
            whose length is used as the number of samples.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 when the dataset is empty.
    """
    total_correct = 0

    model.to(device)
    model.eval()
    # Evaluation needs no gradients; skipping them saves memory and time.
    with torch.no_grad():
        for inputs, labels in tqdm(loader, leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            pred = model(inputs).argmax(dim=1)
            # Count the matches of the whole batch at once.
            total_correct += (pred == labels).sum().item()

    num_samples = len(loader.dataset)
    return total_correct / num_samples if num_samples else 0.0

def finetune(model, train_loader, val_loader, num_epochs, criterion, optimizer, path=None, scheduler=None, pretrained=False, pretrained_path=None):
    """Supervised fine-tuning loop with per-epoch checkpointing and validation.

    Args:
        model: module to train; moved to `device`.
        train_loader: iterable of `(inputs, labels)` training batches.
        val_loader: loader handed to `compute_accuracy` after every epoch.
        num_epochs: number of passes over `train_loader`.
        criterion: loss function called as `criterion(output, labels)`.
        optimizer: optimizer stepping the trainable parameters.
        path: optional directory; created if missing, and a checkpoint
            `model_ep_XX.pth` is written there after each epoch.
        scheduler: optional LR scheduler; stepped once per epoch while the
            learning rate of the first param group is above 5e-5.
        pretrained: when true and `pretrained_path` is given, load that
            state_dict before training.
        pretrained_path: file holding the state_dict to load.
    """
    print('beginning to train model')
    if path and not os.path.exists(path):
      os.makedirs(path)
    if pretrained and (pretrained_path is not None):
      # map_location lets a checkpoint saved on another device load here.
      model.load_state_dict(torch.load(pretrained_path, map_location=device))
    model.to(device)
    for epoch in tqdm(range(1, num_epochs + 1)):
        # Python float: summing the loss tensors themselves would keep their
        # autograd graphs alive for the whole epoch.
        total_loss = 0.0
        start_time = time.perf_counter()
        model.train()
        for inputs, labels in tqdm(train_loader, leave=False):
          inputs, labels = inputs.to(device), labels.to(device)
          optimizer.zero_grad()
          output = model(inputs)
          loss = criterion(output, labels)
          loss.backward()
          optimizer.step()
          total_loss += loss.item()
        if path:
          torch.save(model.state_dict(), f'{path}/model_ep_{epoch:02d}.pth')
        end_time = time.perf_counter()
        duration = end_time - start_time

        # Accuracy on the validation loader (leaves the model in eval mode;
        # train mode is restored at the top of the next epoch).
        val_acc = compute_accuracy(model, val_loader)

        current_lr = optimizer.param_groups[0]['lr']

        if scheduler and current_lr > 5e-5:
            scheduler.step()

        print(f'epoch {epoch:2}',
              f'loss: {total_loss:.3f}',
              f'time: {duration:.3f}',
              f'val acc: {val_acc:.4f}',
              sep='\t')

In [None]:
# Flower names keyed by their 1-based class id (as a string).
cat_to_name = {
    "21": "fire lily", "3": "canterbury bells", "45": "bolero deep blue", "1": "pink primrose",
    "34": "mexican aster", "27": "prince of wales feathers", "7": "moon orchid", "16": "globe-flower",
    "25": "grape hyacinth", "26": "corn poppy", "79": "toad lily", "39": "siam tulip",
    "24": "red ginger", "67": "spring crocus", "35": "alpine sea holly", "32": "garden phlox",
    "10": "globe thistle", "6": "tiger lily", "93": "ball moss", "33": "love in the mist",
    "9": "monkshood", "102": "blackberry lily", "14": "spear thistle", "19": "balloon flower",
    "100": "blanket flower", "13": "king protea", "49": "oxeye daisy", "15": "yellow iris",
    "61": "cautleya spicata", "31": "carnation", "64": "silverbush", "68": "bearded iris",
    "63": "black-eyed susan", "69": "windflower", "62": "japanese anemone", "20": "giant white arum lily",
    "38": "great masterwort", "4": "sweet pea", "86": "tree mallow", "101": "trumpet creeper",
    "42": "daffodil", "22": "pincushion flower", "2": "hard-leaved pocket orchid", "54": "sunflower",
    "66": "osteospermum", "70": "tree poppy", "85": "desert-rose", "99": "bromelia",
    "87": "magnolia", "5": "english marigold", "92": "bee balm", "28": "stemless gentian",
    "97": "mallow", "57": "gaura", "40": "lenten rose", "47": "marigold",
    "59": "orange dahlia", "48": "buttercup", "55": "pelargonium", "36": "ruby-lipped cattleya",
    "91": "hippeastrum", "29": "artichoke", "71": "gazania", "90": "canna lily",
    "18": "peruvian lily", "98": "mexican petunia", "8": "bird of paradise", "30": "sweet william",
    "17": "purple coneflower", "52": "wild pansy", "84": "columbine", "12": "colt's foot",
    "11": "snapdragon", "96": "camellia", "23": "fritillary", "50": "common dandelion",
    "44": "poinsettia", "53": "primula", "72": "azalea", "65": "californian poppy",
    "80": "anthurium", "76": "morning glory", "37": "cape flower", "56": "bishop of llandaff",
    "60": "pink-yellow dahlia", "82": "clematis", "58": "geranium", "75": "thorn apple",
    "41": "barbeton daisy", "95": "bougainvillea", "43": "sword lily", "83": "hibiscus",
    "78": "lotus lotus", "88": "cyclamen", "94": "foxglove", "81": "frangipani",
    "74": "rose", "89": "watercress", "73": "water lily", "46": "wallflower",
    "77": "passion flower", "51": "petunia",
}

# Same names re-keyed by the 0-based id (original id minus one, still a string).
category = {str(int(class_id) - 1): flower for class_id, flower in cat_to_name.items()}

In [None]:
# Run configuration. Note that `epochs` and `PATH` are defined here but the
# finetune call below uses a literal 25 epochs and MODEL_PATH.
epochs = 30
PATH = 'DINO_ViT_test.pth'
MODEL_PATH = "Finetuned_DINO_Vit.pth"

# Model, loss and optimizer for the 102-class problem.
vit = ViT_Dino(num_classes=102)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=vit.parameters(), lr=1e-4)

# Train for 25 epochs, checkpointing under MODEL_PATH, with no LR scheduler.
finetune(
    model=vit,
    train_loader=dataloaders['training'],
    val_loader=dataloaders['validation'],
    num_epochs=25,
    criterion=criterion,
    optimizer=optimizer,
    path=MODEL_PATH,
    scheduler=False,
)

beginning to train model


  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

epoch  1	loss: 45.328	time: 8.929	val acc: 0.8449


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

epoch  2	loss: 4.283	time: 8.844	val acc: 0.9352


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

epoch  3	loss: 0.869	time: 9.159	val acc: 0.9647


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

epoch  4	loss: 0.276	time: 8.222	val acc: 0.9657


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

epoch  5	loss: 0.113	time: 9.179	val acc: 0.9725


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

epoch  6	loss: 0.075	time: 8.601	val acc: 0.9725


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

epoch  7	loss: 0.092	time: 9.242	val acc: 0.9755


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

epoch  8	loss: 0.177	time: 9.218	val acc: 0.9637


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

epoch  9	loss: 0.098	time: 8.310	val acc: 0.9578


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

epoch 10	loss: 0.086	time: 9.242	val acc: 0.9706


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

epoch 11	loss: 0.053	time: 8.454	val acc: 0.9755


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

epoch 12	loss: 0.045	time: 9.231	val acc: 0.9764


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

epoch 13	loss: 0.037	time: 9.160	val acc: 0.9715


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

In [None]:
# Evaluate on the held-out test split and report one summary line instead of
# printing a line for every sample.
correct = 0
count = 0
vit.to(device)  # move the model once, not on every iteration
vit.eval()
with torch.no_grad():  # inference only: no autograd bookkeeping needed
  for inputs, labels in dataloaders['testing']:
    inputs, labels = inputs.to(device), labels.to(device)
    pred = vit(inputs).argmax(dim=1)
    correct += (pred == labels).sum().item()
    count += labels.size(0)
print(f'test accuracy: {correct / max(count, 1):.4f} ({correct}/{count})')

80 80
58 58
0 0
45 45
100 100
89 89
42 42
50 50
6 6
23 23
76 76
36 36
73 73
72 72
20 20
36 36
51 51
74 74
74 74
40 58
86 86
35 35
93 93
31 31
99 99
54 54
77 77
50 50
27 27
87 87
22 22
80 80
44 44
72 72
68 68
50 50
97 97
22 22
29 29
41 41
79 79
87 87
37 37
93 93
76 76
76 76
57 57
77 77
75 75
45 45
46 46
94 94
58 58
67 67
99 99
97 89
72 72
59 59
72 72
72 72
46 46
63 63
71 71
19 19
81 81
13 13
67 67
74 74
90 90
84 84
82 82
50 50
70 70
72 77


KeyboardInterrupt: 

In [None]:
# Path to a downloaded .jpg test image (edit this to classify another picture).
PATH = '/content/drive/MyDrive/COSMOSBIPINNATUSSensationCosmo_1.jpg'

# Load the picture the same way the dataset class does: 300x300,
# nearest-neighbour resampling (resample=0), forced to three channels.
with Image.open(PATH) as img:
  image = img.convert('RGB').resize((300,300),resample=0)
plt.imshow(image)
plt.show()

# Same preprocessing as the validation / testing pipelines.
img_trans = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
image = img_trans(image).unsqueeze(0)  # add the batch axis -> (1, 3, 224, 224)

# Show the normalised tensor. The channel axis must be *moved* last with
# permute; reshaping to (224, 224, 3) would scramble the pixels.
plt.imshow(image[0].permute(1, 2, 0))
plt.show()

vit.eval()
vit.to(device)
image = image.to(device)
with torch.no_grad():  # inference only
  output = vit(image)
_, pred = torch.max(output, 1)
print(pred)
print("MODEL PREDICTION: ", category[f"{pred.item()}"])