<a href="https://colab.research.google.com/github/AmbiTyga/Bio-VI-BERT/blob/main/Multi-label-Classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title EarlyStopping
class EarlyStopping(object):
    """Signal when a monitored metric has stopped improving.

    Args:
        mode: ``'min'`` if lower metric values are better, ``'max'`` if higher
            values are better.
        min_delta: Margin by which a value must beat the best one to count as
            an improvement. Interpreted as an absolute amount, or as a
            percentage of the best value when ``percentage`` is true.
        patience: Number of consecutive non-improving calls to ``step`` after
            which ``step`` returns ``True``. ``0`` disables early stopping.
        percentage: Whether ``min_delta`` is a percentage of the best value.

    Raises:
        ValueError: If ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # Early stopping disabled: every value counts as an improvement
            # and step() never asks the caller to stop.
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record one metric value and report whether training should stop.

        Args:
            metrics: Scalar metric for the current epoch. Anything ``float()``
                accepts works (Python number, NumPy scalar, one-element tensor).
                The value is stored in ``self.best`` as passed, not converted.

        Returns:
            ``True`` when the metric is NaN or when ``patience`` consecutive
            calls brought no improvement, ``False`` otherwise.
        """
        import math

        # Check for NaN before anything else so that a NaN can never be stored
        # as the best value (every later comparison against it would be False).
        if math.isnan(float(metrics)):
            return True

        if self.best is None:
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta, percentage):
        """Install ``self.is_better(a, best)`` for the requested mode and delta type."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            else:
                self.is_better = lambda a, best: a > best + min_delta
        else:
            # min_delta is a percentage of the current best value.
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (
                            best * min_delta / 100)
            else:
                self.is_better = lambda a, best: a > best + (
                            best * min_delta / 100)

In [None]:
# Install timm (pytorch-image-models) from a fresh clone of its repository:
# first the requirements, then the package itself via setup.py.
!git clone https://github.com/rwightman/pytorch-image-models.git
%cd pytorch-image-models
!pip install -r requirements.txt -q
!python setup.py install -q
%cd ..

import os
# Send signal 9 to this kernel's own process; cells below this line in the
# same cell never run, and all in-memory state is lost.
os.kill(os.getpid(), 9)
# Restart Runtime

Cloning into 'pytorch-image-models'...
remote: Enumerating objects: 59, done.[K
remote: Counting objects: 100% (59/59), done.[K
remote: Compressing objects: 100% (49/49), done.[K
remote: Total 4632 (delta 21), reused 31 (delta 8), pack-reused 4573[K
Receiving objects: 100% (4632/4632), 15.82 MiB | 34.24 MiB/s, done.
Resolving deltas: 100% (3320/3320), done.
/content/pytorch-image-models
running install
running bdist_egg
running egg_info
creating timm.egg-info
writing timm.egg-info/PKG-INFO
writing dependency_links to timm.egg-info/dependency_links.txt
writing requirements to timm.egg-info/requires.txt
writing top-level names to timm.egg-info/top_level.txt
writing manifest file 'timm.egg-info/SOURCES.txt'
reading manifest template 'MANIFEST.in'
writing manifest file 'timm.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
creating build
creating build/lib
creating build/lib/timm
copying timm/version.py -> build/lib/timm

In [2]:
# Fetch the dataset archive from the project repository, then extract it with 7z.
!wget https://raw.githubusercontent.com/AmbiTyga/Bio-VI-BERT/main/Dataset.7z
!7z x /content/Dataset.7z

--2021-03-10 08:50:13--  https://raw.githubusercontent.com/AmbiTyga/Bio-VI-BERT/main/Dataset.7z
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 17231696 (16M) [application/octet-stream]
Saving to: ‘Dataset.7z’


2021-03-10 08:50:14 (61.3 MB/s) - ‘Dataset.7z’ saved [17231696/17231696]


7-Zip [64] 16.02 : Copyright (c) 1999-2016 Igor Pavlov : 2016-05-21
p7zip Version 16.02 (locale=en_US.UTF-8,Utf16=on,HugeFiles=on,64 bits,4 CPUs Intel(R) Xeon(R) CPU @ 2.20GHz (406F0),ASM,AES-NI)

Scanning the drive for archives:
  0M Scan /content/                   1 file, 17231696 bytes (17 MiB)

Extracting archive: /content/Dataset.7z
--
Path = /content/Dataset.7z
Type = 7z
Physical Size = 17231696
Headers Size = 6422
Method = LZMA2:24

In [3]:
import timm
import pandas as pd
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import Dataset, DataLoader, sampler
from sklearn.preprocessing import LabelEncoder
import os
from sklearn.model_selection import train_test_split
from glob import glob
from PIL import Image
from torchvision import transforms

## Cleaning the metadata and building the final dataset

In [4]:
## Collect the path of every file below ./Dataset, skipping any path that contains '.csv'

imgs = []
for path, subdirs, files in os.walk('./Dataset'):
    for name in files:
        file_path = os.path.join(path, name)
        if '.csv' in file_path:
            continue
        imgs.append(file_path)

In [5]:
# Load the metadata table shipped inside the dataset folder and preview its first rows.
data = pd.read_csv("/content/Dataset/all_meta_data.csv")
data.head()

Unnamed: 0,phylum,class,genus,species,form,sample,image_name,image_url,img_path
0,Nematoda,Chromadorea,Enterobius,Enterobius vermicularis,egg,intestinal tissue,Evermicularis_worm4_HB.jpg,https://www.cdc.gov//dpdx/enterobiasis/images/...,./Dataset/Enterobius/Evermicularis_worm4_HB.jpg
1,Nematoda,Chromadorea,Enterobius,Enterobius vermicularis,egg,intestinal tissue,Evermicularis_egg_HBa.jpg,https://www.cdc.gov//dpdx/enterobiasis/images/...,./Dataset/Enterobius/Evermicularis_egg_HBa.jpg
2,Nematoda,Chromadorea,Enterobius,Enterobius vermicularis,egg,intestinal tissue,Evermicularis_egg_wtmt.jpg,https://www.cdc.gov//dpdx/enterobiasis/images/...,./Dataset/Enterobius/Evermicularis_egg_wtmt.jpg
3,Nematoda,Chromadorea,Enterobius,Enterobius vermicularis,egg,intestinal tissue,Evermicularis_SC_egg.jpg,https://www.cdc.gov//dpdx/enterobiasis/images/...,./Dataset/Enterobius/Evermicularis_SC_egg.jpg
4,Nematoda,Chromadorea,Enterobius,Enterobius vermicularis,egg,intestinal tissue,Evermicularis_egg_UVa.jpg,https://www.cdc.gov//dpdx/enterobiasis/images/...,./Dataset/Enterobius/Evermicularis_egg_UVa.jpg


In [6]:
import urllib.request as req

def download(url,file_name):
  """Download ``url`` and store the response body at ``file_name``.

  Args:
    url: Address of the resource to fetch.
    file_name: Destination path. Missing parent directories are created
      first so the write cannot fail on an absent folder.

  Raises:
    urllib.error.URLError: If the resource cannot be retrieved.
  """
  import os

  target_dir = os.path.dirname(file_name)
  if target_dir:
    os.makedirs(target_dir, exist_ok=True)
  req.urlretrieve(url, file_name)

In [7]:
# Download every image listed in the metadata whose path was not found on disk
def check_file_download(x):
  """Download one ``[image_url, img_path]`` pair unless ``img_path`` is already in ``imgs``.

  Returns nothing, so applying it to a column leaves that column filled with None.
  """
  already_on_disk = x[1] in imgs
  if not already_on_disk:
    download(*x)


url_path_pairs = data[['image_url','img_path']].values.tolist()
data['image_path_url'] = url_path_pairs
data['image_path_url'] = data['image_path_url'].apply(check_file_download)

In [8]:
# Write the metadata table to CSV without the index column.
data.to_csv('/content/Parasitesv1.csv',index = False)

In [17]:
# Keep only the species that have more than 15 rows. The surviving groups are
# concatenated in groupby order and the index is rebuilt from 0.
species_groups = [grp for _, grp in data.groupby('species') if len(grp) > 15]
data = pd.concat(species_groups, ignore_index=True)

In [18]:
# Encode the species names as integer ids and keep them in a new 'species_label' column.
le = LabelEncoder()
data['species_label'] = le.fit_transform(data['species'])

In [19]:
# Hold out 25% of the rows for validation, stratified on species, with a fixed seed.
train, val = train_test_split(
    data, test_size=0.25, random_state=2021,stratify = data['species'])

In [20]:
# Persist both splits as CSV (index dropped); later cells read them back from these paths.
train.to_csv('/content/train.csv',index = False)

val.to_csv('/content/val.csv',index = False)

# Dataset class

In [12]:
class SpeciesDataset(Dataset):
  """Image dataset that yields ``(image, species_id)`` pairs read from a CSV file.

  The CSV must contain the columns ``species`` and ``img_path``. Species ids
  come from a ``LabelEncoder`` fitted on this file's ``species`` column, so two
  instances agree on the ids only when their CSVs contain the same species.

  Args:
    csv_file: Path of the CSV file to read.
    transform: Callable applied to every loaded PIL image.
  """
  def __init__(self,csv_file,transform):
    super().__init__()
    csv = pd.read_csv(csv_file)[['species','img_path']]
    labels = csv['species'].values

    self.images = csv['img_path'].values
    self.transform = transform

    self.LE = LabelEncoder()
    self.labels = self.LE.fit_transform(labels)

  def __len__(self):
    """Return the number of images in the dataset."""
    return len(self.images)

  def __getitem__(self, index):
    """Return the transformed image at ``index`` and its label as a LongTensor of shape (1,)."""
    # Force three channels: grayscale, palette, RGBA or CMYK files would
    # otherwise not match a 3-channel normalisation / model input. The
    # context manager closes the underlying file once the pixels are converted.
    with Image.open(self.images[index]) as img:
      img = img.convert('RGB')
    img = self.transform(img)

    label = self.labels[index]

    return img, torch.LongTensor([label])

In [13]:

# Per-channel mean / std handed to Normalize by both pipelines
# (the standard ImageNet statistics).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize to 224x224, random horizontal flip, random rotation
# of up to 90 degrees, then tensor conversion and normalisation.
_train_steps = [
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(90),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
train_transformer = transforms.Compose(_train_steps)

# Validation pipeline: the same resize and normalisation, without random augmentation.
_val_steps = [
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
val_transformer = transforms.Compose(_val_steps)

In [None]:
# Display the training split for inspection.
train

Unnamed: 0,phylum,class,genus,species,form,sample,image_name,image_url,img_path,image_path_url,species_label
344,Nematoda,Chromadorea,Strongyloides,Strongyloides stercoralis,adult,feces,OLYMPUS DIGITAL CAMERA,https://www.cdc.gov//dpdx/strongyloidiasis/ima...,./Dataset/Strongyloides/OLYMPUS DIGITAL CAMERA,,13
260,Apicomplexa,Aconoidasida,Plasmodium,Plasmodium ovale,trophozoites,blood,Po_troph_thinD.jpg,https://www.cdc.gov//dpdx/malaria/images/17/Po...,./Dataset/Plasmodium/Po_troph_thinD.jpg,,10
383,Nematoda,Enoplea,Trichuris,Trichuris Trichuria,adult,intestinal tissue,T_trichiura_CHA_A.jpg,https://www.cdc.gov//dpdx/trichuriasis/images/...,./Dataset/Trichuris/T_trichiura_CHA_A.jpg,,15
23,Nematoda,Chromadorea,Ascaris,Ascaris lumbricoides,egg,feces,Ascaris_egg_fert_embryo.jpg,https://www.cdc.gov//dpdx/ascariasis/images/2/...,./Dataset/Ascaris/Ascaris_egg_fert_embryo.jpg,,1
96,Platyhelminthes,Cestoda,Dibothriocephalus,Dibothriocephalus latus,egg,feces,Dlatum_egg_wtmt2.jpg,https://www.cdc.gov//dpdx/diphyllobothriasis/i...,./Dataset/Dibothriocephalus/Dlatum_egg_wtmt2.jpg,,4
...,...,...,...,...,...,...,...,...,...,...,...
174,Sarcomastigophora,Zooflagellate,Giardia,Giardia duodenalis,cyst,wet mount,Giardia_cyst_tric4.jpg,https://www.cdc.gov//dpdx/giardiasis/images/3/...,./Dataset/Giardia/Giardia_cyst_tric4.jpg,,7
151,Nematoda,Chromadorea,Enterobius,Enterobius vermicularis,adult,intestinal tissue,Evermicularis_SC_posterior.jpg,https://www.cdc.gov//dpdx/enterobiasis/images/...,./Dataset/Enterobius/Evermicularis_SC_posterio...,,6
68,Apicomplexa,Conoidasida,Cyclospora,Cyclospora cayetanensis,oocysts,feces,Cyclospora_UV_Henry1.jpg,https://www.cdc.gov//dpdx/cyclosporiasis/image...,./Dataset/Cyclospora/Cyclospora_UV_Henry1.jpg,,3
190,Apicomplexa,Aconoidasida,Plasmodium,Plasmodium falciparum,rings,blood,Pf_rings_thickE.jpg,https://www.cdc.gov//dpdx/malaria/images/1/Pf_...,./Dataset/Plasmodium/Pf_rings_thickE.jpg,,8


In [21]:
val = pd.read_csv('/content/val.csv')
# np.bincount stores the number of rows with label k at position k, so
# class_counts[label] really is the frequency of that label. value_counts().values
# is ordered by frequency instead and therefore cannot be indexed by label id.
class_counts = np.bincount(val['species_label'].values)
# One weight per row: the inverse of the frequency of the row's class.
weights = torch.tensor([1/class_counts[label] for label in val['species_label'].values])

In [22]:
train = pd.read_csv('/content/train.csv')
# np.bincount stores the number of rows with label k at position k, so
# class_counts[label] is the frequency of that label. value_counts().values is
# ordered by frequency and cannot be indexed by label id.
class_counts = np.bincount(train['species_label'].values)
# One weight per row: the inverse of the frequency of the row's class.
weights = torch.tensor([1/class_counts[label] for label in train['species_label'].values])

train_dataset = SpeciesDataset('/content/train.csv',transform=train_transformer)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=48,
                                           sampler=torch.utils.data.WeightedRandomSampler(weights,num_samples=496))

val = pd.read_csv('/content/val.csv')
class_counts = np.bincount(val['species_label'].values)
weights = torch.tensor([1/class_counts[label] for label in val['species_label'].values])

# NOTE(review): the validation loader also draws a weighted random sample, so
# each pass evaluates a different random subset of rows - confirm this is intended.
val_dataset = SpeciesDataset('/content/val.csv',transform=val_transformer)
val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=17,
                                         sampler=torch.utils.data.WeightedRandomSampler(weights,num_samples=80))

# ViT-Base (ImageNet-21k)

In [34]:
class ViT(nn.Module):
  """Single-head classifier: timm ViT backbone, dropout, then one linear layer.

  The backbone is created with ``pretrained=True`` and ``num_classes=512``, so
  it outputs a 512-dimensional vector per image.

  Args:
    num_classes_classifier: Number of output classes of the final linear layer.
  """
  def __init__(self,num_classes_classifier=16):
    super().__init__()
    build_backbone = timm.models.vision_transformer.vit_base_patch16_224_in21k
    self.img_transformer = build_backbone(pretrained=True,num_classes = 512)
    self.drop = nn.Dropout(0.4)
    self.classifier = nn.Linear(512,num_classes_classifier)

  def forward(self,img):
    """Return the class logits for a batch of images."""
    features = self.img_transformer(img)
    return self.classifier(self.drop(features))

In [35]:
# Instantiate the classifier with its default number of classes (16).
model = ViT()

Removing representation layer for fine-tuning.
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth" to /root/.cache/torch/hub/checkpoints/jx_vit_base_patch16_224_in21k-e5005f0a.pth


In [None]:
# Restore weights from 'vit-base.bin', the file name the training loop saves to.
# NOTE(review): this cell sits before the training cell, so on a fresh run the
# file only exists if it was produced by an earlier session - confirm.
model.load_state_dict(torch.load('vit-base.bin'))

<All keys matched successfully>

In [36]:
# Train the species classifier with early stopping on the validation loss.
# Use the GPU when one is visible and fall back to the CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),lr = 1e-4)
loss_fn = nn.CrossEntropyLoss().to(device)
es = EarlyStopping(patience = 10)
# Seed the stopper with +inf so the first finite validation loss is an improvement.
es.best = np.inf
# Defined up front so the patience message below can never hit an unbound name
# (it used to be created only inside the "improved" branch).
best_loss = np.inf

model.to(device)

train_losses = []
train_accs = []
val_losses = []
val_accs = []
for epoch in range(100):
  # ---- training pass ----
  train_epoch_loss = 0
  train_epoch_acc = 0
  model.train()
  steps = 0
  for idx, batch in enumerate(train_loader):
    image, target = batch
    image = image.to(device)
    # The dataset returns labels with shape (batch, 1); the loss needs (batch,).
    target = target.flatten().to(device)

    optim.zero_grad()

    prediction = model(image)
    loss = loss_fn(prediction, target)

    num_corrects = (prediction.argmax(dim=1).view(target.size()).data == target.data).sum()
    acc = 100.0 * (num_corrects/target.size(0))
    loss.backward()
    optim.step()
    steps += 1

    train_epoch_loss += loss.item()
    train_epoch_acc += acc.item()

  # ---- validation pass ----
  val_epoch_loss = 0
  val_epoch_acc = 0
  model.eval()
  val_steps = 0
  with torch.no_grad():
    for idx, batch in enumerate(val_loader):
      image, target = batch
      image = image.to(device)
      target = target.flatten().to(device)

      prediction = model(image)
      loss = loss_fn(prediction, target)

      num_corrects = (prediction.argmax(dim=1).view(target.size()).data == target.data).sum()
      acc = 100.0 * (num_corrects/target.size(0))

      val_epoch_loss += loss.item()
      val_epoch_acc += acc.item()
      val_steps +=1

  # Per-batch averages for this epoch.
  val_epoch_loss /=val_steps
  val_epoch_acc /=val_steps

  print(f"Epoch: {epoch+1:02}, Train Loss: {train_epoch_loss/steps:.3f}, Train Acc: {train_epoch_acc/steps:.2f}%, Val. Loss: {val_epoch_loss:.3f}, Val. Acc: {val_epoch_acc:.2f}% \n{90*'='}")

  if val_epoch_loss < es.best:
    # New best validation loss: checkpoint the weights.
    best_loss = val_epoch_loss
    print(f'\r\t{es.best:.3f} --> {best_loss:.3f}  Updating')
    torch.save(model.state_dict(),'vit-base.bin')

    # NOTE(review): the history lists are extended only on improving epochs,
    # so they do not hold one entry per epoch - confirm this is intended.
    train_losses.append(train_epoch_loss/steps)
    train_accs.append(train_epoch_acc/steps)

    val_losses.append(val_epoch_loss)
    val_accs.append(val_epoch_acc)

  if es.step(torch.tensor(val_epoch_loss)):
    print(f'\r\tPatience complete! Best Loss is {es.best:.3f}')
    break
  if val_epoch_loss > best_loss:
    print(f'\r\t Patience -> {es.patience - es.num_bad_epochs}')

Epoch: 01, Train Loss: 2.582, Train Acc: 24.43%, Val. Loss: 2.430, Val. Acc: 35.59% 
	inf --> 2.430  Updating
Epoch: 02, Train Loss: 2.276, Train Acc: 39.39%, Val. Loss: 2.234, Val. Acc: 35.59% 
	2.430 --> 2.234  Updating
Epoch: 03, Train Loss: 1.959, Train Acc: 56.25%, Val. Loss: 1.949, Val. Acc: 48.14% 
	2.234 --> 1.949  Updating
Epoch: 04, Train Loss: 1.675, Train Acc: 68.37%, Val. Loss: 1.718, Val. Acc: 57.55% 
	1.949 --> 1.718  Updating
Epoch: 05, Train Loss: 1.359, Train Acc: 76.70%, Val. Loss: 1.437, Val. Acc: 63.43% 
	1.718 --> 1.437  Updating
Epoch: 06, Train Loss: 1.067, Train Acc: 83.71%, Val. Loss: 1.451, Val. Acc: 55.88% 
	 Patience -> 9
Epoch: 07, Train Loss: 0.827, Train Acc: 91.48%, Val. Loss: 1.189, Val. Acc: 67.25% 
	1.437 --> 1.189  Updating
Epoch: 08, Train Loss: 0.639, Train Acc: 94.32%, Val. Loss: 1.202, Val. Acc: 61.57% 
	 Patience -> 9
Epoch: 09, Train Loss: 0.490, Train Acc: 97.16%, Val. Loss: 0.935, Val. Acc: 79.02% 
	1.189 --> 0.935  Updating
Epoch: 10, Train

# Multi-label Classification

In [52]:
# Save only the backbone weights (model.img_transformer), without dropout or classifier.
torch.save(model.img_transformer.state_dict(),'vit-main.bin')

In [44]:
class MultiDataset(Dataset):
  """Image dataset with several integer-encoded targets per image.

  Each of the columns ``phylum``, ``class``, ``species``, ``form`` and
  ``sample`` is encoded with its own ``LabelEncoder`` fitted on the frame
  passed in, so two instances share label ids only when their frames contain
  the same values in every label column.

  Args:
    dataFrame: Table with an ``img_path`` column plus the five label columns.
      It is copied (with a fresh 0..n-1 index), so the caller's frame is not
      modified and positional indices always resolve.
    transform: Optional callable applied once to every loaded PIL image.

  Attributes set by the constructor: ``labels`` (the label column names),
  ``label_encoder`` (column name -> fitted encoder) and ``nclasses``
  (column name -> number of distinct values).
  """
  def __init__(self,dataFrame,transform=None):
    # reset_index returns a new frame: encoding the label columns below no
    # longer writes into the caller's object, and .loc[idx] works for any idx
    # in range(len(self)) even if the input frame had a shuffled index.
    self.data = dataFrame.reset_index(drop=True)
    self.transform = transform
    self.labels=["phylum","class","species","form","sample"]
    self.label_encoder=self.get_label_encode()
    self.nclasses=self.get_nclasses()

  def __len__(self):
    """Return the number of rows in the dataset."""
    return len(self.data)

  def __getitem__(self, idx):
    """Return ``{'image': image, 'labels': {column: 0-d LongTensor}}`` for row ``idx``."""
    img_path = self.data.loc[idx,"img_path"]
    # Force three channels and close the file as soon as the pixels are converted.
    with Image.open(img_path) as image:
      image = image.convert('RGB')
    # Apply the transform exactly once, and only when one was provided.
    if self.transform is not None:
      image = self.transform(image)

    row = self.data.loc[idx,self.labels].to_dict()
    # Class indices are emitted as int64: nn.CrossEntropyLoss, which this
    # notebook uses, expects integer class-index targets.
    labels = {name: torch.tensor(int(value),dtype=torch.long) for name, value in row.items()}

    dict_data = {
    'image': image,
    'labels': labels
    }
    return dict_data

  def get_label_encode(self):
    """Encode every label column of ``self.data`` in place and return the fitted encoders."""
    d=dict()
    for label in self.labels:
        le = LabelEncoder()
        self.data[label]=le.fit_transform(self.data[label])
        d[label]=le
    return d

  def get_nclasses(self):
    """Return a dict mapping each label column to its number of classes."""
    d=dict()
    for label in self.label_encoder:
        d[label]=len(self.label_encoder[label].classes_)
    return d

In [57]:
class MultiViT(nn.Module):
  """ViT backbone shared by one linear classification head per label.

  Args:
    vit_path: Path of a saved backbone ``state_dict`` (512 output features).
    nclasses: Mapping from label name to its number of classes.
    labels: Label names; one ``nn.Linear(512, nclasses[label])`` head is
      created for each.
  """
  def __init__(self,vit_path,nclasses,labels):
    super(MultiViT,self).__init__()

    # Build the backbone first and load the weights into it afterwards.
    # load_state_dict() returns a report of missing/unexpected keys, not the
    # module, so its return value must not be stored as the backbone.
    backbone = timm.models.vision_transformer.vit_base_patch16_224_in21k(pretrained=False,num_classes = 512)
    # map_location='cpu' lets a checkpoint saved from a GPU model load on any machine.
    backbone.load_state_dict(torch.load(vit_path, map_location='cpu'))
    self.img_transformer = backbone
    self.drop = nn.Dropout(0.38)
    self.classifiers = nn.ModuleDict()
    for label in labels:
      self.classifiers[label]=nn.Linear(512,nclasses[label])


  def forward(self,x):
    """Return a dict mapping each label name to the logits of its head."""
    x=self.img_transformer(x)
    x=self.drop(x)
    return {label: head(x) for label, head in self.classifiers.items()}


In [58]:
# Reload both splits and wrap them in multi-label datasets.
train = pd.read_csv("/content/train.csv")
val = pd.read_csv("/content/val.csv")

train_dataset = MultiDataset(train,train_transformer)
val_dataset = MultiDataset(val,val_transformer)

# Reshuffle the training rows every epoch; without it each epoch sees the
# exact same batches in the same order. Validation order is left fixed.
train_loader = DataLoader(train_dataset,batch_size=32,shuffle=True)
val_loader = DataLoader(val_dataset,batch_size=32)

In [66]:
# Build the multi-head model from the saved backbone weights, with one head per
# label column and head sizes taken from the training dataset.
# NOTE(review): the model is not moved to a device in this cell, while
# train_model sends every batch to `device` - confirm it is moved elsewhere.
multimodel = MultiViT('/content/vit-main.bin',train_dataset.nclasses,train_dataset.labels)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(multimodel.parameters(), lr=1e-4)

In [None]:
def train_model(model, loss_fn, opt,trainloader,valloader, num_epochs=1):
    """Train ``model`` for ``num_epochs`` epochs and keep the best-validating copy.

    Relies on names defined outside this function: ``device``, ``get_loss``,
    ``cal_accuracy``, ``evaluate`` and ``get_avg_acc``.
    NOTE(review): ``evaluate`` is assumed to return
    ``(loss, average_accuracy, per_label_accuracy)``, judging by how its result
    is indexed below - confirm against its definition.

    Args:
        model: Network called on an image batch.
        loss_fn: Loss passed through to ``get_loss`` and ``evaluate``.
        opt: Optimizer stepped once per training batch.
        trainloader: Iterable of dicts with the keys ``"image"`` and
            ``"labels"`` (a dict of tensors).
        valloader: Loader handed to ``evaluate`` after every epoch.
        num_epochs: Number of passes over ``trainloader``.

    Returns:
        ``(train, val, best_model, max_acc)``: two history dicts with the keys
        ``"loss"``, ``"avg"`` and ``"individual"`` (one entry per epoch), a deep
        copy of the model from the epoch with the highest average validation
        accuracy (``""`` if no epoch ran), and that accuracy.
    """
    # Imported here so the function does not depend on another cell having
    # imported these two standard-library modules.
    import copy
    import time

    since = time.time()
    best_model=""
    max_acc=-999
    train={"loss":[],"avg":[],"individual":[]}
    val={"loss":[],"avg":[],"individual":[]}
    for epoch in range(num_epochs):
        print('Epoch {} of {}'.format(epoch+1, num_epochs))
        epoch_loss=0
        epoch_acc={}
        model.train()
        # Iterate over data.
        for i,data in enumerate(trainloader):
            opt.zero_grad()
            img=data["image"].to(device)
            labels=data["labels"]
            # Move every per-label target tensor to the same device as the images.
            for key in labels:
                labels[key]=labels[key].to(device)

            out=model(img)
            loss=get_loss(out,labels,loss_fn)
            epoch_loss+=loss.detach().item()
            loss.backward()
            opt.step()
            # cal_accuracy receives the running accumulator and returns the updated one.
            epoch_acc=cal_accuracy(out,labels,img.shape[0],epoch_acc)

        val_result=evaluate(model,loss_fn,valloader)
        avg_acc,epoch_acc=get_avg_acc(epoch_acc,trainloader)
        avg_loss=epoch_loss/len(trainloader)
        print('Train')
        print(f"Loss: {avg_loss}\nAccuracy: Each_label_acc->{epoch_acc}\n\tAvg acc->{avg_acc}\n")
        print("Validation")
        print(f"Loss: {val_result[0]}\nAccuracy: Each_label_acc->{val_result[2]}\n\tAvg acc->{val_result[1]}")
        print("-"*50)

        train["avg"].append(avg_acc)
        train["individual"].append(epoch_acc)
        train["loss"].append(avg_loss)

        val["avg"].append(val_result[1])
        val["individual"].append(val_result[2])
        val["loss"].append(val_result[0])
        # Snapshot the whole model whenever the average validation accuracy improves.
        if val_result[1]>max_acc:
            max_acc=val_result[1]
            best_model=copy.deepcopy(model)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    return train,val,best_model,max_acc