<a href="https://colab.research.google.com/github/AmbiTyga/Bio-VI-BERT/blob/main/Multi-label-Classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title EarlyStopping
class EarlyStopping(object):
    def __init__(self, mode='min', min_delta=0, patience=10, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        if self.best is None:
            self.best = metrics
            return False

        if torch.isnan(metrics):
            return True

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs >= self.patience:
            return True

        return False

    def _init_is_better(self, mode, min_delta, percentage):
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            if mode == 'max':
                self.is_better = lambda a, best: a > best + min_delta
        else:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (
                            best * min_delta / 100)
            if mode == 'max':
                self.is_better = lambda a, best: a > best + (
                            best * min_delta / 100)

In [None]:
!git clone https://github.com/rwightman/pytorch-image-models.git
%cd pytorch-image-models
!pip install -r requirements.txt -q
!python setup.py install -q
%cd ..

import os
os.kill(os.getpid(), 9)
# Restart Runtime

Cloning into 'pytorch-image-models'...
remote: Enumerating objects: 85, done.[K
remote: Counting objects: 100% (85/85), done.[K
remote: Compressing objects: 100% (65/65), done.[K
remote: Total 4658 (delta 29), reused 47 (delta 12), pack-reused 4573[K
Receiving objects: 100% (4658/4658), 15.88 MiB | 27.84 MiB/s, done.
Resolving deltas: 100% (3328/3328), done.
/content/pytorch-image-models
running install
running bdist_egg
running egg_info
creating timm.egg-info
writing timm.egg-info/PKG-INFO
writing dependency_links to timm.egg-info/dependency_links.txt
writing requirements to timm.egg-info/requires.txt
writing top-level names to timm.egg-info/top_level.txt
writing manifest file 'timm.egg-info/SOURCES.txt'
reading manifest template 'MANIFEST.in'
writing manifest file 'timm.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
creating build
creating build/lib
creating build/lib/timm
copying timm/__init__.py -> build/lib/ti

In [2]:
!wget https://raw.githubusercontent.com/AmbiTyga/Bio-VI-BERT/main/Dataset.7z
!7z x /content/Dataset.7z

--2021-03-10 19:37:56--  https://raw.githubusercontent.com/AmbiTyga/Bio-VI-BERT/main/Dataset.7z
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 17231696 (16M) [application/octet-stream]
Saving to: ‘Dataset.7z’


2021-03-10 19:37:57 (68.8 MB/s) - ‘Dataset.7z’ saved [17231696/17231696]


7-Zip [64] 16.02 : Copyright (c) 1999-2016 Igor Pavlov : 2016-05-21
p7zip Version 16.02 (locale=en_US.UTF-8,Utf16=on,HugeFiles=on,64 bits,4 CPUs Intel(R) Xeon(R) CPU @ 2.20GHz (406F0),ASM,AES-NI)

Scanning the drive for archives:
  0M Scan /content/                   1 file, 17231696 bytes (17 MiB)

Extracting archive: /content/Dataset.7z
--
Path = /content/Dataset.7z
Type = 7z
Physical Size = 17231696
Headers Size = 6422
Method = LZMA2:24

In [3]:
import timm
import pandas as pd
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import Dataset, DataLoader, sampler
from sklearn.preprocessing import LabelEncoder
import os,time
from math import fsum
from sklearn.model_selection import train_test_split
from glob import glob
from PIL import Image
from torchvision import transforms

## Cleaning and getting a final data

In [4]:
## Getting images(file path) from the directories 

imgs = []
for path, subdirs, files in os.walk('./Dataset'):
    for name in files:
        imgs.append(os.path.join(path, name))
imgs = [x for x in imgs if '.csv' not in x]

In [5]:
data = pd.read_csv("/content/Dataset/all_meta_data.csv")
data.head()

Unnamed: 0,phylum,class,genus,species,form,sample,image_name,image_url,img_path
0,Nematoda,Chromadorea,Enterobius,Enterobius vermicularis,egg,intestinal tissue,Evermicularis_worm4_HB.jpg,https://www.cdc.gov//dpdx/enterobiasis/images/...,./Dataset/Enterobius/Evermicularis_worm4_HB.jpg
1,Nematoda,Chromadorea,Enterobius,Enterobius vermicularis,egg,intestinal tissue,Evermicularis_egg_HBa.jpg,https://www.cdc.gov//dpdx/enterobiasis/images/...,./Dataset/Enterobius/Evermicularis_egg_HBa.jpg
2,Nematoda,Chromadorea,Enterobius,Enterobius vermicularis,egg,intestinal tissue,Evermicularis_egg_wtmt.jpg,https://www.cdc.gov//dpdx/enterobiasis/images/...,./Dataset/Enterobius/Evermicularis_egg_wtmt.jpg
3,Nematoda,Chromadorea,Enterobius,Enterobius vermicularis,egg,intestinal tissue,Evermicularis_SC_egg.jpg,https://www.cdc.gov//dpdx/enterobiasis/images/...,./Dataset/Enterobius/Evermicularis_SC_egg.jpg
4,Nematoda,Chromadorea,Enterobius,Enterobius vermicularis,egg,intestinal tissue,Evermicularis_egg_UVa.jpg,https://www.cdc.gov//dpdx/enterobiasis/images/...,./Dataset/Enterobius/Evermicularis_egg_UVa.jpg


In [6]:
import urllib.request as req

def download(url,file_name):
  imgurl =url
  req.urlretrieve(imgurl, file_name)

In [7]:
# Dropping unnecessary datapoints
def check_file_download(x):
  if x[1] not in imgs:
    download(*x)


data['image_path_url'] = data[['image_url','img_path']].values.tolist()
data['image_path_url'] = data['image_path_url'].apply(check_file_download)

In [8]:
data.to_csv('/content/Parasitesv1.csv',index = False)

In [9]:
data = [x for _,x in data.groupby('species') if len(x)>15]

data = pd.concat(data,ignore_index=True)

In [10]:
labels = ["phylum","class","species","form","sample"]

In [11]:
label_encoders = dict()
for x in labels:
  le = LabelEncoder()
  data[f'{x}_label'] = le.fit_transform(data[x])
  label_encoders[x] = le

In [12]:
train, val = train_test_split(
    data, test_size=0.13, random_state=2021,stratify = data['species'])

In [13]:
train.to_csv('/content/train.csv',index = False)

val.to_csv('/content/val.csv',index = False)

# Dataset Function

In [14]:
class SpeciesDataset(Dataset):
  def __init__(self,csv_file,transform):
    super().__init__()
    csv = pd.read_csv(csv_file)[['species_label','img_path']]
    self.labels = csv['species_label'].values

    self.images = csv['img_path'].values
    self.transform = transform

    # self.LE = LabelEncoder()
    # self.labels = self.LE.fit_transform(labels)    

  def __len__(self):
    # return size of dataset
    return len(self.images)

  def __getitem__(self, index):
    img = Image.open(self.images[index])
    img = self.transform(img)

    label = self.labels[index]

    return img, torch.LongTensor([label])

In [15]:

train_transformer = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.RandomHorizontalFlip(),
        # transforms.
        # transforms.ColorJitter(hue=.05, saturation=.05),
        transforms.RandomRotation(90),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

val_transformer = transforms.Compose([
        transforms.Resize((224,224)),
#         transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

In [113]:
31*16

496

In [16]:
train = pd.read_csv('/content/train.csv')
class_counts = train['species_label'].value_counts().values
weights = torch.tensor([1/class_counts[label] for label in train['species_label'].values])

train_dataset = SpeciesDataset('/content/train.csv',transform=train_transformer)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=48,
                                           sampler=torch.utils.data.WeightedRandomSampler(weights,num_samples=496))

val = pd.read_csv('/content/val.csv')
class_counts = val['species_label'].value_counts().values
weights = torch.tensor([1/class_counts[label] for label in val['species_label'].values])

val_dataset = SpeciesDataset('/content/val.csv',transform=val_transformer)
val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=17,
                                         sampler=torch.utils.data.WeightedRandomSampler(weights,num_samples=80))

# VIT Base Imagenet 21k

In [17]:
class ViT(nn.Module):
  def __init__(self,num_classes_classifier=16):
    super().__init__()
    self.img_transformer = timm.models.vision_transformer.vit_base_patch16_224_in21k(pretrained=True,num_classes = 512)
    self.drop = nn.Dropout(0.4)
    self.classifier = nn.Linear(512,num_classes_classifier)

  def forward(self,img):
    repr = self.img_transformer(img)
    repr = self.drop(repr)
    output = self.classifier(repr)

    return output  

In [18]:
model = ViT()

Removing representation layer for fine-tuning.
Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth" to /root/.cache/torch/hub/checkpoints/jx_vit_base_patch16_224_in21k-e5005f0a.pth


In [19]:
device = torch.device('cuda')
optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),lr = 1e-4)
loss_fn = nn.CrossEntropyLoss().to(device)
es = EarlyStopping(patience = 10)
es.best = np.inf

model.cuda()

train_losses = []
train_accs = []
val_losses = []
val_accs = []
for epoch in range(100):
  train_epoch_loss = 0
  train_epoch_acc = 0
  model.train()
  steps = 0
  for idx, batch in enumerate(train_loader):
    image, target = batch
    image = image.to(device)
    target = target.flatten().to(device)

    optim.zero_grad()

    prediction = model(image)
    # print(prediction,target)
    loss = loss_fn(prediction, target)

    num_corrects = (prediction.argmax(dim=1).view(target.size()).data == target.data).sum()
    acc = 100.0 * (num_corrects/target.size(0))
    loss.backward()
    optim.step()
    steps += 1
    
    # if steps % 6 == 0:
    #     print (f'\t Train - Epoch: {epoch+1}, Idx: {idx+1}, Training Loss: {loss.item():.4f}, Training Accuracy: {acc.item(): .2f}%')
    
    train_epoch_loss += loss.item()
    train_epoch_acc += acc.item()

  val_epoch_loss = 0
  val_epoch_acc = 0
  model.eval()
  val_steps = 0
  with torch.no_grad():
    for idx, batch in enumerate(val_loader):
      image, target = batch
      image = image.to(device)
      target = target.flatten().to(device)

      prediction = model(image)
      loss = loss_fn(prediction, target)
      
      num_corrects = (prediction.argmax(dim=1).view(target.size()).data == target.data).sum()
      acc = 100.0 * (num_corrects/target.size(0))

      val_epoch_loss += loss.item()
      val_epoch_acc += acc.item()
      val_steps +=1

  
  val_epoch_loss /=val_steps
  val_epoch_acc /=val_steps

  # val_losses.append(val_epoch_loss)
  # val_accs.append(val_epoch_acc)
  
  print(f"Epoch: {epoch+1:02}, Train Loss: {train_epoch_loss/steps:.3f}, Train Acc: {train_epoch_acc/steps:.2f}%, Val. Loss: {val_epoch_loss:.3f}, Val. Acc: {val_epoch_acc:.2f}% \n{90*'='}")
  
  if val_epoch_loss < es.best:
    best_loss = val_epoch_loss
    print(f'\r\t{es.best:.3f} --> {best_loss:.3f}  Updating')
    torch.save(model.img_transformer.state_dict(),'vit-base.bin')

    train_losses.append(train_epoch_loss/steps)
    train_accs.append(train_epoch_acc/steps)

    val_losses.append(val_epoch_loss)
    val_accs.append(val_epoch_acc)
    
  if es.step(torch.tensor(val_epoch_loss)):
    print(f'\r\tPatience complete! Best Loss is {es.best:.3f}')
    break
  if val_epoch_loss > best_loss:
    print(f'\r\t Patience -> {es.patience - es.num_bad_epochs}')

Epoch: 01, Train Loss: 2.609, Train Acc: 27.27%, Val. Loss: 2.519, Val. Acc: 19.80% 
	inf --> 2.519  Updating
Epoch: 02, Train Loss: 2.295, Train Acc: 43.56%, Val. Loss: 2.319, Val. Acc: 25.69% 
	2.519 --> 2.319  Updating
Epoch: 03, Train Loss: 2.003, Train Acc: 51.52%, Val. Loss: 2.093, Val. Acc: 39.90% 
	2.319 --> 2.093  Updating
Epoch: 04, Train Loss: 1.740, Train Acc: 63.26%, Val. Loss: 1.844, Val. Acc: 54.02% 
	2.093 --> 1.844  Updating
Epoch: 05, Train Loss: 1.476, Train Acc: 76.52%, Val. Loss: 1.726, Val. Acc: 50.00% 
	1.844 --> 1.726  Updating
Epoch: 06, Train Loss: 1.187, Train Acc: 84.66%, Val. Loss: 1.543, Val. Acc: 49.51% 
	1.726 --> 1.543  Updating
Epoch: 07, Train Loss: 0.889, Train Acc: 90.34%, Val. Loss: 1.360, Val. Acc: 65.10% 
	1.543 --> 1.360  Updating
Epoch: 08, Train Loss: 0.647, Train Acc: 96.40%, Val. Loss: 1.058, Val. Acc: 76.67% 
	1.360 --> 1.058  Updating
Epoch: 09, Train Loss: 0.447, Train Acc: 98.48%, Val. Loss: 1.072, Val. Acc: 71.47% 
	 Patience -> 9
Epoch

# Multi-label Classification

In [28]:
data

Unnamed: 0,phylum,class,genus,species,form,sample,image_name,image_url,img_path,image_path_url,phylum_label,class_label,species_label,form_label,sample_label
0,Nematoda,Chromadorea,Ancylostoma,Ancyclostoma sp.,egg,feces,Hookworm_egg_wtmt.jpg,https://www.cdc.gov//dpdx/hookworm/images/1/Ho...,./Dataset/Ancylostoma/Hookworm_egg_wtmt.jpg,,2,2,0,3,1
1,Nematoda,Chromadorea,Ancylostoma,Ancyclostoma sp.,egg,feces,Hookworm_egg_BAM1.jpg,https://www.cdc.gov//dpdx/hookworm/images/1/Ho...,./Dataset/Ancylostoma/Hookworm_egg_BAM1.jpg,,2,2,0,3,1
2,Nematoda,Chromadorea,Ancylostoma,Ancyclostoma sp.,egg,feces,Hookworm_egg_BAM_MCS.jpg,https://www.cdc.gov//dpdx/hookworm/images/1/Ho...,./Dataset/Ancylostoma/Hookworm_egg_BAM_MCS.jpg,,2,2,0,3,1
3,Nematoda,Chromadorea,Ancylostoma,Ancyclostoma sp.,egg,feces,Hookworm_2x2_B.jpg,https://www.cdc.gov//dpdx/hookworm/images/1/Ho...,./Dataset/Ancylostoma/Hookworm_2x2_B.jpg,,2,2,0,3,1
4,Nematoda,Chromadorea,Ancylostoma,Ancyclostoma sp.,egg,feces,Hookworm_2x2_C.jpg,https://www.cdc.gov//dpdx/hookworm/images/1/Ho...,./Dataset/Ancylostoma/Hookworm_2x2_C.jpg,,2,2,0,3,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
383,Nematoda,Enoplea,Trichuris,Trichuris Trichuria,adult,intestinal tissue,T_trichiura_CHA_A.jpg,https://www.cdc.gov//dpdx/trichuriasis/images/...,./Dataset/Trichuris/T_trichiura_CHA_A.jpg,,2,4,15,0,2
384,Nematoda,Enoplea,Trichuris,Trichuris Trichuria,adult,intestinal tissue,Trichuris_trichiura_adult_Duke.jpg,https://www.cdc.gov//dpdx/trichuriasis/images/...,./Dataset/Trichuris/Trichuris_trichiura_adult_...,,2,4,15,0,2
385,Nematoda,Enoplea,Trichuris,Trichuris Trichuria,adult,intestinal tissue,T_trichiura_adult_HB1.jpg,https://www.cdc.gov//dpdx/trichuriasis/images/...,./Dataset/Trichuris/T_trichiura_adult_HB1.jpg,,2,4,15,0,2
386,Nematoda,Enoplea,Trichuris,Trichuris Trichuria,adult,intestinal tissue,T_trichiura_adult_HB2.jpg,https://www.cdc.gov//dpdx/trichuriasis/images/...,./Dataset/Trichuris/T_trichiura_adult_HB2.jpg,,2,4,15,0,2


In [20]:
class MultiDataset(Dataset):
  def __init__(self,dataFrame,label_encoders,transform=None):
    self.data = dataFrame
    self.transform = transform
    self.labels = ["phylum_label","class_label","species_label","form_label","sample_label"]
    self.label_encoder = label_encoders
    self.nclasses = self.get_nclasses()
    total = fsum(self.nclasses.values())

    self.weights = {k:v/total for k,v in self.nclasses.items()}

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    img_path = self.data.loc[idx,"img_path"]
    image = Image.open(img_path)
    image = self.transform(image)
    labels = self.data.loc[idx,self.labels].to_dict()
    for label in labels:
      labels[label] = torch.tensor(labels[label],dtype = torch.long)
            
    dict_data = {
    'image': image,
    'labels': labels
    }
    return dict_data
  
  # def get_label_encode(self):
  #   d=dict()
  #   for label in self.labels:            
  #       le = LabelEncoder()
  #       self.data[label]=le.fit_transform(self.data[label])
  #       d[label]=le
  #   return d
    
  def get_nclasses(self):
    return self.data[self.labels].nunique().to_dict()

In [21]:
class MultiViT(nn.Module):
  def __init__(self,vit_path,nclasses,labels):
    super(MultiViT,self).__init__()


    self.img_transformer = timm.models.vision_transformer.vit_base_patch16_224_in21k(pretrained=False,num_classes = 512)
    self.img_transformer.load_state_dict(torch.load(vit_path))

    for param in self.img_transformer.parameters():
      param.requires_grad = False
    self.drop = nn.Dropout(0.38)
    self.classifiers = nn.ModuleDict()
    for label in labels:
      self.classifiers[label]=nn.Linear(512,nclasses[label])

    
  def forward(self,x):
    x=self.img_transformer(x)
    x=self.drop(x)
    y=dict()
    for label in self.classifiers.keys():
        y[label]=self.classifiers[label](x)
    # print(y.keys())
    return y


In [29]:
train = pd.read_csv("/content/train.csv")
val = pd.read_csv("/content/val.csv")

train_dataset = MultiDataset(train,label_encoders,train_transformer)
val_dataset = MultiDataset(val,label_encoders,val_transformer)

train_loader = DataLoader(train_dataset,batch_size=32)
val_loader = DataLoader(val_dataset,batch_size=32)

In [23]:
multimodel = MultiViT('/content/vit-base.bin',train_dataset.nclasses,train_dataset.labels)
multimodel.to(device)
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.Adam(filter(lambda p: p.requires_grad, multimodel.parameters()),lr = 1e-4)

Removing representation layer for fine-tuning.


In [27]:
train_loader.dataset.weights

{'class': 0.14, 'form': 0.24, 'phylum': 0.1, 'sample': 0.2, 'species': 0.32}

{'class': 0.14, 'form': 0.24, 'phylum': 0.1, 'sample': 0.2, 'species': 0.32}

In [25]:
#@title Functions for training
def get_loss(out,labels,loss_fn,weights):
    loss=0
    for label in labels:
        loss+=loss_fn(out[label],labels[label])*weights[label]
    return loss


def cal_accuracy(out,labels,batch_size,epoch_acc):
    for key in out:
        q=(out[key].detach().argmax(axis=1)==labels[key]).sum().item()/batch_size
        if not key in epoch_acc.keys():
            epoch_acc[key]=q
        else:
            epoch_acc[key]+=q
    return epoch_acc


def get_avg_acc(epoch_acc,loader):
    acc=0
    for key in epoch_acc:
        epoch_acc[key]/=len(loader)
        acc+=epoch_acc[key]
        
    return acc/len(epoch_acc),epoch_acc

def evaluate(model,loss_fn,loader):
    model.eval()
    epoch_acc={}
    epoch_loss=0
    for batch in loader:
        img=batch["image"].to(device)
        labels=batch["labels"]
        for key in labels:
            labels[key]=labels[key].to(device)
        
        out=model(img)
        loss=get_loss(out,labels,loss_fn,loader.dataset.weights)
        epoch_loss+=loss.detach().item()
        epoch_acc=cal_accuracy(out,labels,img.shape[0],epoch_acc)
        
    avg_acc,epoch_acc=get_avg_acc(epoch_acc,loader)
    return (epoch_loss/len(loader),avg_acc,epoch_acc)

def train_model(model, loss_fn, opt,trainloader,valloader, num_epochs=1):
    since = time.time()
    best_model=""
    max_loss=np.inf
    train={"loss":[],"avg":[],"individual":[]}
    val={"loss":[],"avg":[],"individual":[]}
    for epoch in range(num_epochs):
        print('Epoch {} of {}'.format(epoch+1, num_epochs))
        epoch_loss=0
        epoch_acc={}
        model.train()
        # Iterate over data.
        for i,data in enumerate(trainloader):
            opt.zero_grad()
            img=data["image"].to(device)
            labels=data["labels"]
            for key in labels:
                labels[key]=labels[key].to(device)
            

            out=model(img)
            # print(out,labels,sep = '\n\n')
            loss=get_loss(out,labels,loss_fn,trainloader.dataset.weights)
            epoch_loss+=loss.detach().item()
            loss.backward()
            opt.step()
            epoch_acc=cal_accuracy(out,labels,img.shape[0],epoch_acc)
            
        val_result=evaluate(model,loss_fn,valloader)
        avg_acc,epoch_acc=get_avg_acc(epoch_acc,trainloader)
        avg_loss=epoch_loss/len(trainloader)
        print('Train')
        print(f"Loss: {avg_loss}\nAccuracy: Each_label_acc->{epoch_acc}\n\tAvg acc->{avg_acc}\n")
        print("Validation")
        print(f"Loss: {val_result[0]}\nAccuracy: Each_label_acc->{val_result[2]}\n\tAvg acc->{val_result[1]}")
        print("-"*50)
        
        train["avg"].append(avg_acc)
        train["individual"].append(epoch_acc)
        train["loss"].append(avg_loss)
        
        val["avg"].append(val_result[1])
        val["individual"].append(val_result[2])
        val["loss"].append(val_result[0])
        if val_result[0]<max_loss:
            max_loss=val_result[0]
            torch.save(model.state_dict(),'multimodal.bin')

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    return train,val,best_model,max_loss

In [31]:
train,val,best_model,best_acc=train_model(multimodel, loss_fn, optim,train_loader,val_loader,num_epochs=140)

Epoch 1 of 140
Train
Loss: 2.1029217026450415
Accuracy: Each_label_acc->{'phylum_label': 0.49348262032085566, 'class_label': 0.6064505347593583, 'species_label': 0.4455213903743315, 'form_label': 0.272225935828877, 'sample_label': 0.2369652406417112}
	Avg acc->0.41092914438502676

Validation
Loss: 2.063020348548889
Accuracy: Each_label_acc->{'phylum_label': 0.665296052631579, 'class_label': 0.7014802631578947, 'species_label': 0.3873355263157895, 'form_label': 0.37746710526315785, 'sample_label': 0.3404605263157895}
	Avg acc->0.4944078947368421
--------------------------------------------------
Epoch 2 of 140
Train
Loss: 2.01131948557767
Accuracy: Each_label_acc->{'phylum_label': 0.6099598930481284, 'class_label': 0.6799799465240641, 'species_label': 0.6049465240641712, 'form_label': 0.3467580213903743, 'sample_label': 0.3140040106951872}
	Avg acc->0.511129679144385

Validation
Loss: 1.9759199023246765
Accuracy: Each_label_acc->{'phylum_label': 0.743421052631579, 'class_label': 0.77467

In [32]:
multimodel.load_state_dict(torch.load('/content/multimodal.bin'))
multimodel.to(device)
multimodel.eval()
val_loader = DataLoader(val_dataset,batch_size=len(val_dataset))

target = next(iter(val_loader))
img,labels = target['image'], target['labels']

img = img.to(device)
pred = multimodel(img)

In [33]:
for x in pred:
  pred[x] = pred[x].cpu().detach().argmax(axis = 1).numpy()
  labels[x] = labels[x].numpy()

In [34]:
from sklearn.metrics import classification_report, confusion_matrix

for label in pred:
  classes = val_loader.dataset.label_encoder[label].classes_
  print(f"\nClassification Report for {label}:\n")
  print(classification_report(labels[label],pred[label],target_names=classes))
  print('='*95)

KeyError: ignored

In [71]:
pred

{'class': tensor([2, 5, 0, 3, 2, 3, 1, 0, 5, 2, 6, 2, 3, 0, 0, 0, 5, 2, 3, 3, 2, 3, 2, 0,
         0, 1, 6, 6, 4, 0, 0, 6, 2, 0, 2, 0, 2, 3, 2, 0, 3, 0, 5, 0, 0, 2, 0, 2,
         0, 3, 0]),
 'form': tensor([ 3,  1, 11,  6,  0,  6,  3, 11,  6,  0,  1,  0,  6, 11, 11, 11, 11,  0,
          6,  6,  0,  3,  3, 11, 11,  3,  1,  1,  3,  8, 11,  1,  5,  8,  3, 10,
          5,  6,  3, 11,  6, 11, 11, 11, 11,  5, 11,  0, 11,  6, 11]),
 'phylum': tensor([2, 0, 1, 2, 2, 1, 3, 1, 0, 2, 4, 2, 1, 1, 1, 1, 0, 2, 1, 1, 2, 2, 2, 1,
         1, 3, 4, 4, 2, 1, 1, 4, 2, 1, 2, 1, 2, 1, 2, 1, 1, 1, 0, 1, 1, 2, 1, 2,
         1, 1, 1]),
 'sample': tensor([1, 9, 0, 9, 3, 6, 1, 0, 9, 3, 9, 1, 1, 0, 0, 0, 9, 1, 1, 1, 2, 1, 1, 0,
         0, 1, 9, 9, 2, 0, 0, 9, 1, 0, 1, 0, 1, 1, 2, 0, 1, 0, 9, 0, 0, 1, 0, 1,
         0, 1, 0]),
 'species': tensor([ 1,  5, 11, 12, 14, 12,  4, 10,  5, 14,  7,  1,  3, 11, 10, 11,  5,  1,
          3,  3,  6, 12,  1,  8, 11,  4,  7,  7, 15,  8, 10,  7, 13,  8,  1,  9,
          0