<a href="https://colab.research.google.com/github/AmbiTyga/Bio-VI-BERT/blob/main/Multi-label-Classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title EarlyStopping
class EarlyStopping(object):
    """Decide when to stop training because a monitored metric stalled.

    Call :meth:`step` once per epoch with the monitored value. It returns
    ``True`` when training should stop and ``False`` otherwise.

    Args:
        mode: ``'min'`` if lower values are better, ``'max'`` if higher
            values are better.
        min_delta: Margin by which a value must beat the best one seen so
            far to count as an improvement. An absolute amount, or a
            percentage of the best value when ``percentage`` is true.
        patience: Number of consecutive non-improving calls after which
            :meth:`step` returns ``True``. ``0`` disables stopping.
        percentage: Interpret ``min_delta`` as a percentage of ``best``.

    Raises:
        ValueError: If ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # patience == 0 turns the stopper off: every value counts as an
            # improvement and step() never asks to stop.
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    @staticmethod
    def _is_nan(value):
        """Return True if ``value`` (float or single-element tensor) is NaN."""
        # Local import keeps this cell runnable before the notebook's import
        # cell; float() accepts Python numbers and single-element tensors.
        import math
        return math.isnan(float(value))

    def step(self, metrics):
        """Record one metric value and report whether training should stop.

        Args:
            metrics: The monitored value for the epoch (Python number or
                single-element tensor).

        Returns:
            ``True`` if the value is NaN or if ``patience`` consecutive calls
            failed to improve on the best value, otherwise ``False``.
        """
        # NaN is checked first so it can never be stored as the best value.
        if self._is_nan(metrics):
            return True

        if self.best is None:
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            return True

        return False

    def _init_is_better(self, mode, min_delta, percentage):
        """Install the comparison used to decide whether a value improved."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            if mode == 'max':
                self.is_better = lambda a, best: a > best + min_delta
        else:
            # NOTE(review): the margin is best * min_delta / 100, so its sign
            # follows the sign of `best`; a negative `best` moves the
            # threshold in the opposite direction.
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (
                            best * min_delta / 100)
            if mode == 'max':
                self.is_better = lambda a, best: a > best + (
                            best * min_delta / 100)

In [1]:
!pip install timm==0.4.5 -q

[?25l[K     |█▏                              | 10kB 22.0MB/s eta 0:00:01[K     |██▎                             | 20kB 24.9MB/s eta 0:00:01[K     |███▍                            | 30kB 12.0MB/s eta 0:00:01[K     |████▋                           | 40kB 9.5MB/s eta 0:00:01[K     |█████▊                          | 51kB 4.0MB/s eta 0:00:01[K     |██████▉                         | 61kB 4.5MB/s eta 0:00:01[K     |████████                        | 71kB 5.0MB/s eta 0:00:01[K     |█████████▏                      | 81kB 5.2MB/s eta 0:00:01[K     |██████████▎                     | 92kB 5.6MB/s eta 0:00:01[K     |███████████▍                    | 102kB 5.9MB/s eta 0:00:01[K     |████████████▌                   | 112kB 5.9MB/s eta 0:00:01[K     |█████████████▊                  | 122kB 5.9MB/s eta 0:00:01[K     |██████████████▉                 | 133kB 5.9MB/s eta 0:00:01[K     |████████████████                | 143kB 5.9MB/s eta 0:00:01[K     |█████████████████       

In [2]:
# Download the training archive from the project's GitHub repository and unpack it.
# wget saves into the current working directory while 7z reads /content/Train.7z,
# so this assumes the notebook runs with /content as its working directory (as on Colab).
!wget https://raw.githubusercontent.com/AmbiTyga/Bio-VI-BERT/main/Train.7z
!7z x /content/Train.7z

--2021-03-22 04:00:19--  https://raw.githubusercontent.com/AmbiTyga/Bio-VI-BERT/main/Train.7z
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 21599892 (21M) [application/octet-stream]
Saving to: ‘Train.7z’


2021-03-22 04:00:21 (44.8 MB/s) - ‘Train.7z’ saved [21599892/21599892]


7-Zip [64] 16.02 : Copyright (c) 1999-2016 Igor Pavlov : 2016-05-21
p7zip Version 16.02 (locale=en_US.UTF-8,Utf16=on,HugeFiles=on,64 bits,4 CPUs Intel(R) Xeon(R) CPU @ 2.30GHz (306F0),ASM,AES-NI)

Scanning the drive for archives:
  0M Scan /content/                   1 file, 21599892 bytes (21 MiB)

Extracting archive: /content/Train.7z
--
Path = /content/Train.7z
Type = 7z
Physical Size = 21599892
Headers Size = 14592
Method = LZMA2:24
Solid = 

In [3]:
# Download the validation archive from the project's GitHub repository and unpack it.
# wget saves into the current working directory while 7z reads /content/Val.7z,
# so this assumes the notebook runs with /content as its working directory (as on Colab).
!wget https://raw.githubusercontent.com/AmbiTyga/Bio-VI-BERT/main/Val.7z
!7z x /content/Val.7z

--2021-03-22 04:00:22--  https://raw.githubusercontent.com/AmbiTyga/Bio-VI-BERT/main/Val.7z
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4849556 (4.6M) [application/octet-stream]
Saving to: ‘Val.7z’


2021-03-22 04:00:24 (34.4 MB/s) - ‘Val.7z’ saved [4849556/4849556]


7-Zip [64] 16.02 : Copyright (c) 1999-2016 Igor Pavlov : 2016-05-21
p7zip Version 16.02 (locale=en_US.UTF-8,Utf16=on,HugeFiles=on,64 bits,4 CPUs Intel(R) Xeon(R) CPU @ 2.30GHz (306F0),ASM,AES-NI)

Scanning the drive for archives:
  0M Scan /content/                   1 file, 4849556 bytes (4736 KiB)

Extracting archive: /content/Val.7z
--
Path = /content/Val.7z
Type = 7z
Physical Size = 4849556
Headers Size = 2370
Method = LZMA2:6m
Solid = +
Blocks = 1


In [4]:
import timm
import pandas as pd
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import Dataset, DataLoader, sampler
from sklearn.preprocessing import LabelEncoder
import os,time
from math import fsum
from sklearn.model_selection import train_test_split
from glob import glob
from PIL import Image
from torchvision import transforms

## Cleaning the data and preparing the final dataset

In [None]:
from PIL import Image
from glob import glob
# Rewrite every validation image in place as RGB.
images = [x for x in glob('./val/*/*') if 'val.csv' not in x]
for i in images:
  # Convert while the file is open, then close the handle before the same
  # path is overwritten by save().
  with Image.open(i) as source_image:
    rgb_image = source_image.convert('RGB')
  rgb_image.save(i)

In [5]:
# Names of the target columns predicted by the model.
labels = ["phylum","class","species","form","sample"]

train = pd.read_csv('/content/train/train.csv')
val = pd.read_csv('/content/val/val.csv')

# Drop rows whose species is the literal string 'na'. .copy() makes the
# filtered frames independent, so the column assignments below (and in the
# label-encoding cell) do not write to a slice of the original frame.
train = train[train['species']!='na'].copy()
val = val[val['species']!='na'].copy()

# Rewrite the stored image paths as literal (non-regex) substring
# replacements so they point at the extracted ./train and ./val folders.
train['img_path'] = train['img_path'].str.replace('/train','./train',regex=False)
val['img_path'] = val['img_path'].str.replace('/content','./val',regex=False)

In [6]:
# One LabelEncoder per target column: fitted on the training split, then
# applied to both splits to create the integer `<column>_label` columns.
label_encoders = {}
for column in labels:
  encoder = LabelEncoder()
  encoder.fit(train[column])
  train[f'{column}_label'] = encoder.transform(train[column])
  val[f'{column}_label'] = encoder.transform(val[column])
  label_encoders[column] = encoder

In [7]:
# Persist both prepared splits (without the index column) for the training cells.
for frame, csv_path in ((train, '/content/train.csv'), (val, '/content/val.csv')):
  frame.to_csv(csv_path, index=False)

# Dataset Objects

In [11]:
# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
transformer = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [None]:
class MultiDataset(Dataset):
  """Dataset yielding one image plus its five encoded labels per row.

  Args:
    dataFrame: Frame with an ``img_path`` column and the five ``*_label``
      integer columns listed in ``self.labels``.
    label_encoders: Stored unchanged as ``self.label_encoder`` so callers
      can map encoded ids back to class names.
    transform: Optional callable applied to the RGB PIL image. When it is
      ``None`` the PIL image is returned as-is.

  Attributes:
    nclasses: Number of distinct values per label column in this frame.
    weights: ``nclasses`` divided by the sum over all label columns.
  """

  def __init__(self,dataFrame,label_encoders,transform=None):
    # __getitem__ looks rows up by label with .loc, while samplers hand out
    # positions 0..len-1; a fresh RangeIndex makes the two agree even for
    # frames that were filtered before being passed in.
    self.data = dataFrame.reset_index(drop=True)
    self.transform = transform
    self.labels = ["phylum_label","class_label","species_label","form_label","sample_label"]
    self.label_encoder = label_encoders
    self.nclasses = self.get_nclasses()
    total = fsum(self.nclasses.values())

    self.weights = {k:v/total for k,v in self.nclasses.items()}

  def __len__(self):
    """Number of rows in the underlying frame."""
    return len(self.data)

  def __getitem__(self, idx):
    """Return ``{'image': ..., 'labels': {column: long tensor}}`` for row ``idx``."""
    img_path = self.data.loc[idx,"img_path"]
    # Force three channels and release the file handle as soon as the pixel
    # data has been copied.
    with Image.open(img_path) as handle:
      image = handle.convert('RGB')
    if self.transform is not None:
      image = self.transform(image)

    row_labels = self.data.loc[idx,self.labels].to_dict()
    labels = {
      name: torch.tensor(value,dtype = torch.long)
      for name, value in row_labels.items()
    }

    return {
      'image': image,
      'labels': labels
    }

  def get_nclasses(self):
    """Count the distinct values of every label column in this frame."""
    return self.data[self.labels].nunique().to_dict()

# Multi-label Classification

In [9]:
class MultiViT(nn.Module):
  """ViT backbone with a shared projection and one classifier head per label.

  Args:
    nclasses: Mapping from label name to the number of output classes.
    labels: Iterable of label names; one head is created per name, in order.
  """

  def __init__(self,nclasses,labels):
    super().__init__()

    # Pretrained backbone created with num_classes=0; its output feeds fc1.
    self.img_transformer = timm.models.vision_transformer.vit_base_patch16_224_in21k(
        pretrained=True,
        num_classes=0,
    )
    self.fc1 = nn.Linear(768, 512)
    self.drop = nn.Dropout(0.38)

    # Heads are registered one at a time so they keep the order of `labels`.
    self.classifiers = nn.ModuleDict()
    for name in labels:
      head = nn.Sequential(
          nn.BatchNorm1d(512),
          nn.Linear(512, nclasses[name]),
      )
      self.classifiers[name] = head

  def forward(self,x):
    """Return a dict mapping each label name to that head's output."""
    features = self.drop(self.fc1(self.img_transformer(x)))
    return {name: head(features) for name, head in self.classifiers.items()}


In [12]:
train = pd.read_csv("/content/train.csv")
val = pd.read_csv("/content/val.csv")

# Both splits use the `transformer` pipeline (resize, to-tensor, normalise)
# built in the "Dataset Objects" section.
train_dataset = MultiDataset(train,label_encoders,transformer)
val_dataset = MultiDataset(val,label_encoders,transformer)

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = DataLoader(train_dataset,batch_size=24,shuffle=True)
val_loader = DataLoader(val_dataset,batch_size=17)

In [13]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
multimodel = MultiViT(train_dataset.nclasses,train_dataset.labels)
multimodel.to(device)
loss_fn = nn.CrossEntropyLoss()
# Only parameters that still require gradients are handed to the optimiser.
optim = torch.optim.Adam(filter(lambda p: p.requires_grad, multimodel.parameters()),lr = 1e-4)

Removing representation layer for fine-tuning.
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth" to /root/.cache/torch/hub/checkpoints/jx_vit_base_patch16_224_in21k-e5005f0a.pth


In [15]:
#@title Functions for training
def get_loss(out,labels,loss_fn,weights):
    """Weighted sum of the per-label losses.

    For every key in ``labels`` the loss between ``out[key]`` and
    ``labels[key]`` is scaled by ``weights[key]`` and added to the total.
    """
    total = 0
    for name, target in labels.items():
        total += loss_fn(out[name], target) * weights[name]
    return total


def cal_accuracy(out,labels,batch_size,epoch_acc):
    """Add this batch's per-head accuracy to the running totals.

    For every head in ``out`` the arg-max over axis 1 is compared with
    ``labels[head]``; the fraction of matches (count / ``batch_size``) is
    added to ``epoch_acc[head]``. ``epoch_acc`` is updated in place and
    also returned.
    """
    for head, scores in out.items():
        predicted = scores.detach().argmax(axis=1)
        hits = (predicted == labels[head]).sum().item()
        epoch_acc[head] = epoch_acc.get(head, 0) + hits / batch_size
    return epoch_acc


def get_avg_acc(epoch_acc,loader):
    """Turn accumulated per-head accuracies into per-batch means.

    Each entry of ``epoch_acc`` is divided in place by ``len(loader)``.
    Returns ``(mean over all heads, epoch_acc)``.
    """
    for head in list(epoch_acc):
        epoch_acc[head] = epoch_acc[head] / len(loader)
    overall = sum(epoch_acc.values()) / len(epoch_acc)
    return overall, epoch_acc

def evaluate(model,loss_fn,loader):
    """Run one pass over ``loader`` in eval mode and report loss and accuracy.

    The whole pass runs under ``torch.no_grad()`` so no autograd graph is
    built while scoring. Relies on the notebook-level ``device`` and on
    ``get_loss``, ``cal_accuracy`` and ``get_avg_acc``.

    Args:
        model: Network returning a dict of outputs keyed like the labels.
        loss_fn: Per-head loss passed through to ``get_loss``.
        loader: Loader whose batches hold ``"image"`` and ``"labels"`` and
            whose ``dataset.weights`` supplies the per-head loss weights.

    Returns:
        ``(mean loss per batch, accuracy averaged over heads, per-head accuracy dict)``.
    """
    model.eval()
    epoch_acc={}
    epoch_loss=0
    with torch.no_grad():
        for batch in loader:
            img=batch["image"].to(device)
            labels=batch["labels"]
            for key in labels:
                labels[key]=labels[key].to(device)

            out=model(img)
            loss=get_loss(out,labels,loss_fn,loader.dataset.weights)
            epoch_loss+=loss.detach().item()
            epoch_acc=cal_accuracy(out,labels,img.shape[0],epoch_acc)

    avg_acc,epoch_acc=get_avg_acc(epoch_acc,loader)
    return (epoch_loss/len(loader),avg_acc,epoch_acc)

def train_model(model, loss_fn, opt,trainloader,valloader, num_epochs=1):
    """Train ``model`` and checkpoint it whenever validation loss improves.

    Each epoch performs one optimisation pass over ``trainloader``, evaluates
    on ``valloader`` with ``evaluate``, prints both sets of metrics, and
    saves ``model.state_dict()`` to ``'multimodal.bin'`` when the validation
    loss is lower than any seen before. Relies on the notebook-level
    ``device``.

    Args:
        model: Network returning a dict of outputs keyed like the labels.
        loss_fn: Per-head loss passed through to ``get_loss``.
        opt: Optimiser stepping the model's parameters.
        trainloader: Training loader; ``dataset.weights`` gives loss weights.
        valloader: Validation loader handed to ``evaluate``.
        num_epochs: Number of epochs to run.

    Returns:
        ``(train_history, val_history, best_model, max_loss)``. Each history
        is a dict with ``"loss"``, ``"avg"`` and ``"individual"`` lists
        holding one entry per epoch. ``best_model`` is the path of the saved
        checkpoint, or ``""`` if none was written. ``max_loss`` is the lowest
        validation loss observed.
    """
    since = time.time()
    checkpoint_path = 'multimodal.bin'
    best_model=""
    # Despite its name this tracks the lowest validation loss seen so far.
    max_loss=np.inf
    train={"loss":[],"avg":[],"individual":[]}
    val={"loss":[],"avg":[],"individual":[]}
    for epoch in range(num_epochs):
        print('Epoch {} of {}'.format(epoch+1, num_epochs))
        epoch_loss=0
        epoch_acc={}
        model.train()
        # Iterate over data.
        for data in trainloader:
            opt.zero_grad()
            img=data["image"].to(device)
            labels=data["labels"]
            for key in labels:
                labels[key]=labels[key].to(device)

            out=model(img)
            loss=get_loss(out,labels,loss_fn,trainloader.dataset.weights)
            epoch_loss+=loss.detach().item()
            loss.backward()
            opt.step()
            epoch_acc=cal_accuracy(out,labels,img.shape[0],epoch_acc)

        val_result=evaluate(model,loss_fn,valloader)
        avg_acc,epoch_acc=get_avg_acc(epoch_acc,trainloader)
        avg_loss=epoch_loss/len(trainloader)
        print('Train')
        print(f"Loss: {avg_loss}\nAccuracy: Each_label_acc->{epoch_acc}\n\tAvg acc->{avg_acc}\n")
        print("Validation")
        print(f"Loss: {val_result[0]}\nAccuracy: Each_label_acc->{val_result[2]}\n\tAvg acc->{val_result[1]}")
        print("-"*50)

        train["avg"].append(avg_acc)
        train["individual"].append(epoch_acc)
        train["loss"].append(avg_loss)

        val["avg"].append(val_result[1])
        val["individual"].append(val_result[2])
        val["loss"].append(val_result[0])
        if val_result[0]<max_loss:
            max_loss=val_result[0]
            torch.save(model.state_dict(),checkpoint_path)
            # Record where the best weights live so the caller gets a usable value.
            best_model=checkpoint_path

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    return train,val,best_model,max_loss

In [18]:
# NOTE: this rebinds `train` and `val` (DataFrames until here) to the
# history dicts returned by train_model; `best_acc` receives the fourth
# return value, which is the lowest validation loss.
train, val, best_model, best_acc = train_model(
    model=multimodel,
    loss_fn=loss_fn,
    opt=optim,
    trainloader=train_loader,
    valloader=val_loader,
    num_epochs=140,
)

Epoch 1 of 140
Train
Loss: 1.2894736528396606
Accuracy: Each_label_acc->{'phylum_label': 0.7064275814275814, 'class_label': 0.7069906444906444, 'species_label': 0.46881496881496865, 'form_label': 0.5692567567567567, 'sample_label': 0.7624740124740124}
	Avg acc->0.6427927927927927

Validation
Loss: 0.9163882732391357
Accuracy: Each_label_acc->{'phylum_label': 0.8235294117647058, 'class_label': 0.8655462184873949, 'species_label': 0.5546218487394958, 'form_label': 0.6050420168067226, 'sample_label': 0.8319327731092436}
	Avg acc->0.7361344537815127
--------------------------------------------------
Epoch 2 of 140
Train
Loss: 0.8497385998835435
Accuracy: Each_label_acc->{'phylum_label': 0.8210325710325709, 'class_label': 0.8554660429660428, 'species_label': 0.7212837837837834, 'form_label': 0.7393018018018016, 'sample_label': 0.8766891891891891}
	Avg acc->0.8027546777546777

Validation
Loss: 0.7857448628970555
Accuracy: Each_label_acc->{'phylum_label': 0.8739495798319329, 'class_label': 0.

In [115]:
# Rebuild the network and restore the checkpoint written during training.
multimodel = MultiViT(train_dataset.nclasses,train_dataset.labels)
# map_location places the stored tensors on the active device while loading.
multimodel.load_state_dict(torch.load('/content/multimodal.bin',map_location=device))
multimodel.to(device)
preds, true = [],[]
val_loader = DataLoader(val_dataset,batch_size=17)
multimodel.eval()
with torch.no_grad():
  for batch in val_loader:
    # Index the batch directly instead of unpacking into `labels`, which
    # would overwrite the notebook-level list of label names.
    pred = multimodel(batch['image'].to(device))
    preds.append(pred)
    true.append(batch['labels'])


Removing representation layer for fine-tuning.


In [126]:
# Flat per-head lists of predicted and true class ids; each key gets its own list.
y_pred = {
    key: []
    for key in ('phylum_label', 'class_label', 'species_label', 'form_label', 'sample_label')
}
y_true = {key: [] for key in y_pred}

In [127]:
# Walk the stored batches pairwise and append each head's arg-max
# predictions and true ids to the flat lists.
for batch_preds, batch_true in zip(preds, true):
  for key in y_pred:
    predicted_ids = batch_preds[key].argmax(axis = 1).cpu().tolist()
    y_pred[key].extend(predicted_ids)
    y_true[key].extend(batch_true[key].tolist())

In [133]:
from sklearn.metrics import classification_report, confusion_matrix
# Print one classification report per head, using the fitted encoder's
# class names as row labels where possible.
y = '_label'
for label in y_pred:
  label_le = label.replace(y,'')
  classes = val_loader.dataset.label_encoder[label_le].classes_
  print(f"\nClassification Report for {label}:\n")
  try:
    print(classification_report(y_true[label],y_pred[label],target_names=classes))
  except ValueError:
    # Raised, for example, when the number of classes present in the data
    # differs from len(target_names); fall back to numeric ids. Any other
    # error now surfaces instead of being swallowed.
    print(classification_report(y_true[label],y_pred[label]))
  print('='*95)


Classification Report for phylum_label:

                   precision    recall  f1-score   support

        Amoebozoa       0.80      0.89      0.84         9
      Apicomplexa       0.97      1.00      0.98        57
         Nematoda       0.89      0.97      0.93        33
  Platyhelminthes       0.78      0.64      0.70        11
Sarcomastigophora       1.00      0.56      0.71         9

         accuracy                           0.92       119
        macro avg       0.89      0.81      0.83       119
     weighted avg       0.92      0.92      0.91       119


Classification Report for class_label:

               precision    recall  f1-score   support

 Aconoidasida       1.00      0.97      0.99        40
      Cestoda       0.89      0.73      0.80        11
  Chromadorea       0.79      0.96      0.87        28
  Conoidasida       0.81      1.00      0.89        17
      Enoplea       1.00      0.40      0.57         5
    Tubulinea       0.75      0.67      0.71        

  _warn_prf(average, modifier, msg_start, len(result))
