## 🌴 **Created by** : Lucrece (Jahyun) Shin
## 🌴 **Latest Edit** : January 28, 2022
## 🌴 **Associated Blog Post** : 

# Import Libraries

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import time, os, pickle
from PIL import Image, ImageFile
import pandas as pd
import seaborn as sns
from collections import OrderedDict, deque
from sklearn.metrics import confusion_matrix, classification_report
import cv2 
from google.colab.patches import cv2_imshow
from glob import glob

import torch
from torch import nn, optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

ImageFile.LOAD_TRUNCATED_IMAGES = True
use_cuda = torch.cuda.is_available()

# Import your Google drive if necessary

In [2]:
from google.colab import drive, files
# Mount Google Drive at /content/gdrive so that later cells can copy dataset
# archives and model checkpoints out of it.
drive.mount('/content/gdrive')

Mounted at /content/gdrive


# Import your data folders from your Google Drive

In [3]:
# run this cell once
# Each group below copies one archive from the mounted Drive into the working
# directory, extracts it quietly (-qq) and then deletes the archive.
# NOTE(review): this cell uses both the 'My Drive' and 'MyDrive' spellings of
# the Drive folder - confirm that both resolve on the mounted drive.

# Web image dataset (the folder later used as `root_dir` / `src_root_dir`)
!cp 'gdrive/My Drive/web_3cls+NOcutter.zip' . #web_2cls+benign.zip' . #'gdrive/My Drive/web_2cls+bag_SPLIT.zip' .
!unzip -qq web_3cls+NOcutter.zip
!rm web_3cls+NOcutter.zip 

# Cropped 2-class X-ray dataset
!cp gdrive/MyDrive/Xray_2classes_cropped.zip .
!unzip -qq Xray_2classes_cropped.zip
!rm Xray_2classes_cropped.zip

# 3 classes : gun, knife, benign
!cp 'gdrive/My Drive/Xray-3cls_small.zip' .  # Xray-3cls_small_kitchen.zip
!unzip -qq Xray-3cls_small.zip
!rm Xray-3cls_small.zip 
# remove Kitchen knife & Cutter knife images
!rm Xray-3cls_small/knife/Kitchen*
!rm Xray-3cls_small/knife/Cutter*

# Define Dataloaders

In [4]:
# Image preprocessing pipelines, keyed by the split they are used for.
transform = {}

# Training: resize to 224x224, then augment with a random horizontal flip and
# a random rotation (RandomRotation(50)) before converting to a tensor.
transform['train'] = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(50),
    transforms.ToTensor(),
])

# Validation / test: resize, random horizontal flip, convert to a tensor.
# NOTE(review): the random flip is also applied at evaluation time, so
# validation results are not fully deterministic - confirm this is intended.
transform['valid'] = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

* `root_dir` is a string that contains the name of your source domain data folder. It must contain 3 folders named 'train', 'valid', 'test', each containing folders named by class names.


In [None]:
root_dir = "web_3cls+NOcutter/"

# One ImageFolder per split; the validation and test splits share the
# 'valid' transform pipeline.
train_data = datasets.ImageFolder(root_dir + 'train', transform=transform['train'])
valid_data = datasets.ImageFolder(root_dir + 'valid', transform=transform['valid'])
test_data = datasets.ImageFolder(root_dir + 'test', transform=transform['valid'])

# Report the class-name -> index mapping and the size of every split.
print("Class2idx: ", train_data.class_to_idx)
print('Train images :', len(train_data))
print('Valid images :', len(valid_data))
print('Test images :', len(test_data))

num_workers = 0
batch_size = 20

# Every split gets a shuffling DataLoader with the same batch size.
dataloaders = {
    split: DataLoader(split_data, batch_size=batch_size, num_workers=num_workers, shuffle=True)
    for split, split_data in (('train', train_data), ('valid', valid_data), ('test', test_data))
}

# Define Multi-label Dataloaders (optional)
To use a multi-label dataset as discussed in [my blog post](https://medium.com/mlearning-ai/ch-6-optimizing-data-for-flexible-and-robust-image-recognition-23f4dcce3af7#f4d3), you must organize your folders in the following way:

* [class a]+[class b] for images that contain both class a and b
* [class a] for images that contain only class a

For example, I originally had 3 classes (gun, knife, benign) and made the following folders:

* gun
* gun+benign
* knife
* knife+benign
* gun+knife
* benign

Here I am only considering images that contain at most 2 classes. Please adjust the code if you need to handle images containing more than 2 classes.

In [None]:
class MultiLabelWebDataset(Dataset):
  """Multi-label image dataset read from a folder-per-label-combination tree.

  Every immediate sub-folder of ``root_dir`` is named either ``<class>`` or
  ``<class a>+<class b>[+...]``. Each file found beneath such a folder
  (recursively) becomes one sample whose multi-hot label marks every class
  named by the folder.

  Args:
    root_dir: directory that holds the label-named sub-folders.
    classes: ordered class names; a class's position in this sequence is its
      index in the label vector.
    transform: optional callable applied to the loaded RGB PIL image.
    soft_label_class_name: name of the class that is given a soft label < 1
      instead of 1 (``None`` disables soft labels).
    soft_label: label value used for ``soft_label_class_name``.

  Raises:
    KeyError: if a non-empty folder names a class that is not in ``classes``.
  """

  def __init__(self, root_dir, classes, transform=None, soft_label_class_name=None, soft_label=0.5):
    self.root_dir = root_dir
    self.transform = transform
    self.classes = classes
    self.class_to_idx = {c:i for i, c in enumerate(self.classes)}
    self.soft_label_class_name = soft_label_class_name
    self.soft_label = soft_label
    # Scan the folder tree once; __len__ and __getitem__ both read this cache.
    self.data = self.make_dataset()

  def __len__(self):
    # Use the cached sample list instead of re-walking the file system.
    return len(self.data)

  def _label_for(self, folder_name):
    """Return the multi-hot label (list of floats) encoded by a folder name."""
    label = [0.] * len(self.classes)
    # Class names are joined by "+", e.g. "gun+benign"; any count is accepted.
    for cls in folder_name.split("+"):
      if cls not in self.class_to_idx:
        raise KeyError(
            "Folder '{}' names class '{}', which is not in classes {}".format(
                folder_name, cls, list(self.class_to_idx)))
      if cls == self.soft_label_class_name:
        label[self.class_to_idx[cls]] = self.soft_label
      else:
        label[self.class_to_idx[cls]] = 1.
    return label

  def make_dataset(self):
    """Scan ``root_dir`` and return a list of ``(path, label)`` samples."""
    instances = []
    # Sorted so that the sample order does not depend on the file system.
    for target_class in sorted(os.listdir(self.root_dir)):
      target_dir = os.path.join(self.root_dir, target_class)
      if not os.path.isdir(target_dir):
        continue
      # Join with the directory reported by os.walk so that files in nested
      # sub-folders keep their real path.
      paths = [os.path.join(root, fname)
               for root, _, fnames in sorted(os.walk(target_dir, followlinks=True))
               for fname in sorted(fnames)]
      if not paths:
        # Empty folders contribute no samples and are not validated.
        continue
      label = self._label_for(target_class)
      # Each sample gets its own copy of the label list.
      instances.extend((path, list(label)) for path in paths)
    return instances

  def __getitem__(self, idx):
    """Return ``(image, label_tensor)`` for the sample at ``idx``."""
    if torch.is_tensor(idx):
      idx = idx.tolist()
    path, target = self.data[idx]
    image = Image.open(path).convert('RGB')
    if self.transform:
      image = self.transform(image)
    return image, torch.tensor(target)

* `src_root_dir` is a string that contains the name of your source domain data folder. It must contain 3 folders named 'train', 'valid', 'test', each containing folders named by class names.
* `soft_label` is a weaker label, less than 1. Usually a class is given a label of 1 in a one-hot-encoded label. If a less significant class (if any) is receiving too strong a signal from the model (i.e. large recall), you can give that class a soft label of e.g. 0.5 instead of 1.

In [None]:
src_root_dir = "web_3cls+NOcutter"## insert your data folder here ##
classes=['benign','gun','knife']
soft_label_class_name="benign"
soft_label=0.5 # weaker label < 1 
# Use the pipelines from the `transform` dict: the training split gets the
# augmenting 'train' pipeline, validation and test share the 'valid' one
# (same convention as the single-label dataloaders).
train_data_multi_label = MultiLabelWebDataset(src_root_dir + '/train', classes=classes, transform = transform['train'], soft_label_class_name=soft_label_class_name, soft_label=soft_label)
valid_data_multi_label = MultiLabelWebDataset(src_root_dir + '/valid', classes=classes, transform = transform['valid'], soft_label_class_name=soft_label_class_name, soft_label=soft_label)
test_data_multi_label = MultiLabelWebDataset(src_root_dir + '/test', classes=classes, transform = transform['valid'], soft_label_class_name=soft_label_class_name, soft_label=soft_label)
print("Class2idx: ", train_data_multi_label.class_to_idx)
num_workers = 0
batch_size = 16
# One shuffling DataLoader per split, keyed by split name.
dataloaders_multi_label = {}
dataloaders_multi_label['train'] = DataLoader(train_data_multi_label, batch_size=batch_size, num_workers=num_workers, shuffle=True)
dataloaders_multi_label['valid'] = DataLoader(valid_data_multi_label, batch_size=batch_size, num_workers=num_workers, shuffle=True)
dataloaders_multi_label['test'] = DataLoader(test_data_multi_label, batch_size=batch_size, num_workers=num_workers, shuffle=True)
print('Train images :', len(train_data_multi_label), ", # of training batches:", len(dataloaders_multi_label['train']))
print('Valid images :', len(valid_data_multi_label), ", # of valid batches:", len(dataloaders_multi_label['valid']))
print('Test images :', len(test_data_multi_label), ", # of test batches:", len(dataloaders_multi_label['test']))

# Define ViT model from downloadable pre-trained checkpoints

In [5]:
# Install ml_collections - presumably required by the cloned ViT-pytorch code;
# confirm against that repository.
!pip install ml_collections
# Fetch the PyTorch ViT implementation.
!git clone https://github.com/jeonsworld/ViT-pytorch.git
# Move the repository contents into the working directory so that
# `models.modeling` can be imported directly by the next cell.
!mv ViT-pytorch/* .

Collecting ml_collections
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
[?25l[K     |████▏                           | 10 kB 24.4 MB/s eta 0:00:01[K     |████████▍                       | 20 kB 28.1 MB/s eta 0:00:01[K     |████████████▋                   | 30 kB 12.5 MB/s eta 0:00:01[K     |████████████████▉               | 40 kB 10.1 MB/s eta 0:00:01[K     |█████████████████████           | 51 kB 9.4 MB/s eta 0:00:01[K     |█████████████████████████▎      | 61 kB 9.5 MB/s eta 0:00:01[K     |█████████████████████████████▍  | 71 kB 9.4 MB/s eta 0:00:01[K     |████████████████████████████████| 77 kB 4.6 MB/s 
Building wheels for collected packages: ml-collections
  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone
  Created wheel for ml-collections: filename=ml_collections-0.1.1-py3-none-any.whl size=94524 sha256=b420a6cc3bebfeffe1f619d525527312301318f4c2b435aff63f1677bf416c41
  Stored in directory: /root/.cache/pip/wheels/b7/da/64/33c926a1b10ff197910

In [6]:
from urllib.request import urlretrieve
from models.modeling import VisionTransformer, CONFIGS

# Download the pre-trained ViT-B/16 checkpoint only when it is not already
# present locally, so re-running the cell reuses the existing file.
checkpoint_file = "model_checkpoints/ViT-B_16-224.npz"
os.makedirs("model_checkpoints", exist_ok=True)
if not os.path.isfile(checkpoint_file):
    urlretrieve("https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16-224.npz",
                checkpoint_file)

# Build ViT-B/16 for 224x224 inputs with a 1000-class head, then copy the
# downloaded weights into it.
# NOTE(review): vis=True presumably makes the model also return attention
# weights - confirm in models/modeling.py.
config = CONFIGS["ViT-B_16"]
model = VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=True)
pretrained_weights = np.load(checkpoint_file)
model.load_from(pretrained_weights)

# Define Encoder & Classifier

In [7]:
# The encoder chains the pre-trained model's embedding layer and transformer
# encoder; the 1000-class head of `model` is not part of it.
encoder = nn.Sequential(model.transformer.embeddings, model.transformer.encoder)

ViT_embed_dim = 768  # final embedding dimension for ViT-B
n_classes = 3        # number of classes for my dataset

# A single linear layer maps an embedding vector to per-class logits.
classifier = nn.Linear(in_features=ViT_embed_dim, out_features=n_classes)

In [8]:
!cp gdrive/MyDrive/encoder_1.pt .
!cp gdrive/MyDrive/classifier_1.pt .
encoder.load_state_dict(torch.load("encoder_1.pt"))
classifier.load_state_dict(torch.load("classifier_1.pt"))

<All keys matched successfully>

# Define ViT Fine-tuning Function

In [None]:
def fine_tune_ViT(encoder,     # ViT encoder (pre-trained)
                  classifier,  # single-layer fully-connected classifier
                  dataloaders, # dict with 2 keys, "train" and "valid", containing train & valid dataloaders
                  n_epochs,    # number of epochs to fine-tune
                  lr,          # learning rate
                  multi_label_data=False
                  ):
  """Fine-tune the encoder and the classifier head jointly.

  Each epoch trains on ``dataloaders['train']``, evaluates the mean batch loss
  on ``dataloaders['valid']``, prints it, and saves both state dicts to
  ``encoder_<epoch>.pt`` / ``classifier_<epoch>.pt`` in the working directory.

  Args:
    encoder: module whose call returns a pair; the first item is indexed as
      [batch, token, embedding dim] and token 0 is used for classification.
    classifier: module mapping that token-0 embedding to per-class logits.
    dataloaders: dict with "train" and "valid" iterables of (image, target)
      batches.
    n_epochs: number of passes over the training data.
    lr: learning rate used by both Adam optimizers.
    multi_label_data: if True, targets are multi-hot vectors and
      BCEWithLogitsLoss is used; otherwise targets are class indices and
      CrossEntropyLoss is used.

  Returns:
    The same ``(encoder, classifier)`` objects, updated in place.
  """
  # Run on the GPU when one is available, otherwise fall back to the CPU
  # instead of failing on a hard .cuda() call.
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  encoder.to(device)
  classifier.to(device)

  #  1. Define optimizers and loss function
  optimizer_encoder    = optim.Adam(encoder.parameters(), lr=lr, betas=(0.5, 0.9))
  optimizer_classifier = optim.Adam(classifier.parameters(), lr=lr, betas=(0.5, 0.9))
  ### Different loss criterion for multi-label and single-label data
  if multi_label_data:
    criterion = nn.BCEWithLogitsLoss()
  else: # single-label data (standard)
    criterion = nn.CrossEntropyLoss()

  def batch_loss(img, tgt):
    # Shared by training and validation so both use the same loss convention.
    img, tgt = img.to(device), tgt.to(device)

    ## The encoder has 2 outputs: final embedding vectors for all image tokens and a second output ##
    ## (attention weights for the ViT encoder), which is not used during training/validation. ##
    ## embeddings: [batch_size, n_tokens, embedding dim]  e.g.[16, 197, 768] ##
    embeddings, _ = encoder(img)

    ## Extract [CLS] token (at index 0) 's embeddings used for classification ##
    embedding_cls_token = embeddings[:, 0, :] # [batch_size, embedding dim]
    logits = classifier(embedding_cls_token)  # [batch_size, n_classes]

    if multi_label_data:
      # BCE needs float targets; .float() keeps the tensors on their device.
      return criterion(logits.float(), tgt.float())
    # Cross-entropy takes the class-index targets as they are.
    return criterion(logits, tgt)

  for e in range(n_epochs):
    #  2. Train on training data
    encoder.train()
    classifier.train()
    for img, tgt in dataloaders['train']:
      optimizer_encoder.zero_grad()
      optimizer_classifier.zero_grad()
      loss = batch_loss(img, tgt)
      loss.backward()
      optimizer_encoder.step()
      optimizer_classifier.step()

    #  3. Evaluate on validation data
    encoder.eval()
    classifier.eval()
    val_loss = 0.
    with torch.no_grad():
      for img, tgt in dataloaders['valid']:
        val_loss += batch_loss(img, tgt).item()

    #  4. Log results and save model checkpoints
    # max(..., 1) avoids a division by zero when the validation loader is empty.
    n_val_batches = max(len(dataloaders['valid']), 1)
    print("Epoch: {}/{}   Val CE Loss: {:.5f}".format(e+1, n_epochs, val_loss/n_val_batches))
    torch.save(encoder.state_dict(), 'encoder_{}.pt'.format(e+1))
    torch.save(classifier.state_dict(), 'classifier_{}.pt'.format(e+1))

  return encoder, classifier

# Fine-tune ViT!

In [None]:
# Fine-tune on the multi-label dataloaders for 4 epochs. multi_label_data=True
# must accompany multi-hot targets: it selects BCEWithLogitsLoss inside
# fine_tune_ViT instead of CrossEntropyLoss.
encoder, classifier = fine_tune_ViT(encoder, 
                                    classifier, 
                                    dataloaders_multi_label, 
                                    n_epochs=4, 
                                    lr=3e-6,
                                    multi_label_data=True)