<a href="https://colab.research.google.com/github/AmbiTyga/Bio-VI-BERT/blob/main/Baselines-Multi-label-Classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install timm (pytorch-image-models) directly from its GitHub repository.
# NOTE(review): the install is not pinned to a tag or commit, so a re-run may pull a
# different timm version than the one that produced the outputs in this notebook.
!pip install git+https://github.com/rwightman/pytorch-image-models.git -q

  Building wheel for timm (setup.py) ... [?25l[?25hdone


In [2]:
# Download the training archive from the project repository and extract it with 7-Zip.
!wget https://raw.githubusercontent.com/AmbiTyga/Bio-VI-BERT/main/Train.7z
!7z x /content/Train.7z

--2021-04-19 10:58:48--  https://raw.githubusercontent.com/AmbiTyga/Bio-VI-BERT/main/Train.7z
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13524763 (13M) [application/octet-stream]
Saving to: ‘Train.7z’


2021-04-19 10:58:49 (36.5 MB/s) - ‘Train.7z’ saved [13524763/13524763]


7-Zip [64] 16.02 : Copyright (c) 1999-2016 Igor Pavlov : 2016-05-21
p7zip Version 16.02 (locale=en_US.UTF-8,Utf16=on,HugeFiles=on,64 bits,2 CPUs Intel(R) Xeon(R) CPU @ 2.30GHz (306F0),ASM,AES-NI)

Scanning the drive for archives:
  0M Scan /content/                   1 file, 13524763 bytes (13 MiB)

Extracting archive: /content/Train.7z
--
Path = /content/Train.7z
Type = 7z
Physical Size = 13524763
Headers Size = 13084
Method = LZMA2:24
Solid = 

In [3]:
# Download the validation archive from the project repository and extract it with 7-Zip.
!wget https://raw.githubusercontent.com/AmbiTyga/Bio-VI-BERT/main/Val.7z
!7z x /content/Val.7z

--2021-04-19 10:58:51--  https://raw.githubusercontent.com/AmbiTyga/Bio-VI-BERT/main/Val.7z
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4593856 (4.4M) [application/octet-stream]
Saving to: ‘Val.7z’


2021-04-19 10:58:52 (22.6 MB/s) - ‘Val.7z’ saved [4593856/4593856]


7-Zip [64] 16.02 : Copyright (c) 1999-2016 Igor Pavlov : 2016-05-21
p7zip Version 16.02 (locale=en_US.UTF-8,Utf16=on,HugeFiles=on,64 bits,2 CPUs Intel(R) Xeon(R) CPU @ 2.30GHz (306F0),ASM,AES-NI)

Scanning the drive for archives:
  0M Scan /content/                   1 file, 4593856 bytes (4487 KiB)

Extracting archive: /content/Val.7z
--
Path = /content/Val.7z
Type = 7z
Physical Size = 4593856
Headers Size = 5093
Method = LZMA2:6m
Solid = +
Blocks = 1


In [4]:
import timm
import pandas as pd
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import Dataset, DataLoader, sampler
from sklearn.preprocessing import LabelEncoder
import os,time
from math import fsum
from sklearn.model_selection import train_test_split
from glob import glob
from PIL import Image
from torchvision import transforms
import torch.functional as F

In [None]:
# Deletes the `train` and `val` directories relative to the current working directory.
# NOTE(review): the following cells read ./val/*/* and /content/train/train.csv, so running
# this cell in sequence (e.g. Restart & Run All) would remove data they need — confirm it
# is meant for manual cleanup only.
!rm -r train val

## Cleaning and preparing the final data

In [5]:
# Re-encode every validation image as RGB, overwriting the file in place.
# `Image` and `glob` come from the notebook's import cell, so they are not re-imported here.
images = [path for path in glob('./val/*/*') if 'val.csv' not in path]
for image_path in images:
  # Open inside a context manager so the source file handle is released before
  # the converted copy is written back to the same path.
  with Image.open(image_path) as source_image:
    rgb_image = source_image.convert('RGB')
  rgb_image.save(image_path)

In [6]:
# Target columns; each one is label-encoded into a `<column>_label` column in a later cell.
labels = ["phylum","class","species","form","sample"]

# Load the train / validation metadata tables (absolute Colab paths).
train = pd.read_csv('/content/train/train.csv')
val = pd.read_csv('/content/val/val.csv')

In [7]:
# Fit one LabelEncoder per target column on the training split, write the encoded ids
# to a new `<column>_label` column in both splits, and keep the fitted encoder so the
# ids can be decoded back to names later.
label_encoders = dict()
for x in labels:
  le = LabelEncoder()
  train[f'{x}_label'] = le.fit_transform(train[x])
  # The validation split reuses the encoder fitted on the training split.
  val[f'{x}_label'] = le.transform(val[x])
  label_encoders[x] = le

In [8]:
# Persist both frames, now including the `*_label` columns, without the index column.
train.to_csv('/content/train.csv',index = False)

val.to_csv('/content/val.csv',index = False)

In [None]:
# Report how many distinct values each target column has in train vs. validation.
rule = '-' * 50
for col in labels:
  print(rule + col + rule)
  n_train, n_val = train[col].nunique(), val[col].nunique()
  print(n_train, n_val, sep=' -> ')

--------------------------------------------------phylum--------------------------------------------------
5 -> 5
--------------------------------------------------class--------------------------------------------------
7 -> 7
--------------------------------------------------species--------------------------------------------------
16 -> 16
--------------------------------------------------form--------------------------------------------------
9 -> 9
--------------------------------------------------sample--------------------------------------------------
6 -> 6


# Dataset Objects

In [10]:
class MultiDataset(Dataset):
  """Image dataset that yields one image together with one class id per label column.

  Args:
    dataFrame: table with an ``img_path`` column and the five ``*_label`` columns
      listed in ``self.labels``.
    label_encoders: mapping of fitted label encoders; stored as ``self.label_encoder``
      and not otherwise used by this class.
    transform: optional callable applied to the RGB PIL image. When ``None`` the
      PIL image is returned unchanged.

  Attributes:
    nclasses: ``{label_column: number of distinct values in dataFrame}``.
    weights: ``nclasses`` normalised so the values sum to 1.
  """

  def __init__(self,dataFrame,label_encoders,transform=None):
    self.data = dataFrame
    self.transform = transform
    self.labels = ["phylum_label","class_label","species_label","form_label","sample_label"]
    self.label_encoder = label_encoders
    self.nclasses = self.get_nclasses()
    total = fsum(self.nclasses.values())

    # Each label's weight is its share of the total number of classes.
    self.weights = {k:v/total for k,v in self.nclasses.items()}

  def __len__(self):
    """Number of rows in the underlying table."""
    return len(self.data)

  def __getitem__(self, idx):
    """Return ``{'image': ..., 'labels': {label_column: LongTensor scalar}}`` for row ``idx``.

    ``idx`` is a position in ``0..len(self)-1``.
    """
    # Positional lookup: the original label-based `.loc[idx]` only worked when the
    # frame had a default 0..n-1 index (e.g. it broke after filtering or splitting).
    row = self.data.iloc[idx]
    # Force three channels and release the file handle as soon as pixels are loaded.
    with Image.open(row["img_path"]) as image_file:
      image = image_file.convert('RGB')
    # `transform` defaults to None, so only call it when one was supplied.
    if self.transform is not None:
      image = self.transform(image)
    labels = {name: torch.tensor(int(row[name]),dtype = torch.long) for name in self.labels}

    return {
    'image': image,
    'labels': labels
    }

  def get_nclasses(self):
    """Count the distinct values of every label column in ``self.data``."""
    return self.data[self.labels].nunique().to_dict()

# Baselines

## Resnet200d Multi-label Classification

In [None]:
class MultiResNet(nn.Module):
  """ResNet-200d backbone with one projection branch and one classifier head per label.

  Args:
    nclasses: mapping ``label -> number of classes`` used to size each head.
    labels: iterable of label names; one branch and one head are created per entry,
      in this order.

  ``forward`` returns a dict ``label -> logits``.
  """

  def __init__(self,nclasses,labels):
    super().__init__()

    # num_classes = 0: the code below treats the backbone output as 2048-d features.
    self.img_transformer = timm.models.resnet.resnet200d(pretrained=True,num_classes = 0)

    # One 2048 -> 1024 -> 512 projection per label.
    self.dividers = nn.ModuleList()
    for _ in labels:
      projection = nn.Sequential(
          nn.Linear(2048,1024),
          nn.ReLU(),
          nn.Linear(1024,512),
      )
      self.dividers.append(projection)

    # fc1 produces the concatenated query/key/value inputs for the attention layer.
    self.fc1 = nn.Linear(512,3*512,bias = False)
    self.mha = nn.MultiheadAttention(embed_dim = 512,num_heads = 8)

    # Heads are inserted one by one so their order follows `labels`.
    self.classifiers = nn.ModuleDict()
    for label in labels:
      head = nn.Sequential(
          nn.BatchNorm1d(512),
          nn.Linear(512,nclasses[label]),
      )
      self.classifiers[label] = head

  def forward(self,x):
    features = self.img_transformer(x)
    # Stack the per-label projections along a new leading axis (one slice per label).
    branches = torch.stack([divider(features) for divider in self.dividers],dim = 0)
    q, k, v = self.fc1(branches).split(512,dim = 2)
    attended,_ = self.mha(q,k,v)

    # The i-th attended slice is classified by the i-th head.
    return {
        label: self.classifiers[label](branch)
        for label,branch in zip(self.classifiers.keys(),attended)
    }


In [None]:
# Per-row sampling weights: the inverse frequency of each row's species label, so rows
# of rarer species are drawn more often by WeightedRandomSampler.
# NOTE(review): `train_transformer` / `val_transformer` are only defined in a later cell
# of this notebook, so this cell fails when the notebook is run top-to-bottom on a fresh kernel.
class_counts = train['species_label'].value_counts().to_dict()
weights = torch.tensor([1/class_counts[label] for label in train['species_label'].values])

train_dataset = MultiDataset(train,label_encoders,train_transformer)
# The number of draws per epoch is hard-coded to 2160.
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=32,
                                           sampler=torch.utils.data.WeightedRandomSampler(weights,num_samples=2160))

# Same inverse-frequency weighting, recomputed on the validation split.
class_counts = val['species_label'].value_counts().to_dict()
weights = torch.tensor([1/class_counts[label] for label in val['species_label'].values])

val_dataset = MultiDataset(val,label_encoders,val_transformer)
# NOTE(review): validation batches are also drawn by weighted random sampling (720 draws),
# so metrics computed through this loader are on a resampled set, not one pass over `val`.
val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=16,
                                         sampler=torch.utils.data.WeightedRandomSampler(weights,num_samples=720))

In [None]:
# NOTE(review): the device is hard-coded to CUDA, so `multimodel.to(device)` will fail
# on a machine without a GPU.
device = torch.device('cuda')
# Size the classifier heads from the per-label class counts of the training dataset.
multimodel = MultiResNet(train_dataset.nclasses,train_dataset.labels)
multimodel.to(device)
# A single CrossEntropyLoss instance is shared by every label head.
loss_fn = nn.CrossEntropyLoss()
# Adam over every parameter that requires gradients.
optim = torch.optim.Adam(filter(lambda p: p.requires_grad, multimodel.parameters()),lr = 1e-4)

In [None]:
#@title Training Functions
def get_loss(out,labels,loss_fn,weights):
    """Return the weighted sum of the per-label losses.

    For every key in `labels`, `loss_fn(out[key], labels[key])` is scaled by
    `weights[key]`; the scaled terms are added up. Returns 0 when `labels` is empty.
    """
    return sum(loss_fn(out[name],labels[name])*weights[name] for name in labels)


def cal_accuracy(out,labels,batch_size,epoch_acc):
    """Add this batch's per-label accuracy to the running totals in `epoch_acc`.

    For each key of `out`, the argmax over axis 1 is compared with `labels[key]` and
    the fraction of matches (matches / batch_size) is accumulated under that key.
    `epoch_acc` is updated in place and also returned.
    """
    for name, scores in out.items():
        predicted = scores.detach().argmax(axis=1)
        batch_fraction = (predicted==labels[name]).sum().item()/batch_size
        epoch_acc[name] = epoch_acc.get(name, 0) + batch_fraction
    return epoch_acc


def get_avg_acc(epoch_acc,loader):
    """Turn accumulated per-label accuracies into per-batch means.

    Every value in `epoch_acc` is divided by `len(loader)` in place. Returns
    `(mean over labels, epoch_acc)`.
    """
    running_total = 0
    for name in epoch_acc:
        epoch_acc[name] = epoch_acc[name]/len(loader)
        running_total = running_total + epoch_acc[name]

    mean_over_labels = running_total/len(epoch_acc)
    return mean_over_labels,epoch_acc

def evaluate(model,loss_fn,loader):
    """Run `model` over `loader` without gradients and summarise loss and accuracy.

    Uses the module-level `device` and the helpers `get_loss`, `cal_accuracy` and
    `get_avg_acc`. The per-label loss weights come from `loader.dataset.weights`.

    Returns:
        (mean loss per batch, mean accuracy over labels, per-label accuracy dict)
    """
    model.eval()
    running_acc = {}
    running_loss = 0
    with torch.no_grad():
        for batch in loader:
            images = batch["image"].to(device)
            targets = batch["labels"]
            # Move every label tensor to the same device as the images.
            for name in targets:
                targets[name] = targets[name].to(device)

            predictions = model(images)
            batch_loss = get_loss(predictions,targets,loss_fn,loader.dataset.weights)
            running_loss += batch_loss.item()
            running_acc = cal_accuracy(predictions,targets,images.shape[0],running_acc)

    mean_acc,per_label_acc = get_avg_acc(running_acc,loader)
    mean_loss = running_loss/len(loader)
    return (mean_loss,mean_acc,per_label_acc)

def train_model(model, loss_fn, opt,trainloader,valloader, num_epochs=1):
    """Train `model`, checkpointing on the best validation accuracy, with early stopping.

    Each epoch runs one pass over `trainloader`, then evaluates on `valloader`.
    When the mean validation accuracy beats the best value so far, the epoch's
    metrics are appended to the returned histories and `model.state_dict()` is
    saved to 'multimodal.bin'. Training stops early once 10 epochs scoring below
    the best accuracy have accumulated since the last improvement.

    Relies on the module-level `device` and on `get_loss`, `cal_accuracy`,
    `get_avg_acc` and `evaluate`.

    Args:
        model: network returning a dict of per-label outputs.
        loss_fn: loss applied to each label's output.
        opt: optimizer stepping `model`'s parameters.
        trainloader, valloader: loaders whose dataset exposes `weights`.
        num_epochs: maximum number of epochs.

    Returns:
        (train_history, val_history, best_model, best_acc). The histories only hold
        entries for epochs that improved the best accuracy; `best_model` is always
        the empty string (the weights live in 'multimodal.bin').
    """
    stop_value = 10
    since = time.time()
    best_model=""
    # NOTE(review): hard-coded starting threshold — nothing is saved or recorded until
    # validation accuracy exceeds it; presumably carried over from an earlier run, confirm.
    max_acc = 0.803
    train={"loss":[],"avg":[],"individual":[]}
    val={"loss":[],"avg":[],"individual":[]}
    for epoch in range(num_epochs):
        print("-"*50)
        print('Epoch {} of {}'.format(epoch+1, num_epochs))
        epoch_loss=0
        epoch_acc={}
        model.train()
        for data in trainloader:
            opt.zero_grad()
            img=data["image"].to(device)
            labels=data["labels"]
            for key in labels:
                labels[key]=labels[key].to(device)

            out=model(img)
            loss=get_loss(out,labels,loss_fn,trainloader.dataset.weights)
            epoch_loss+=loss.detach().item()
            loss.backward()
            opt.step()
            epoch_acc=cal_accuracy(out,labels,img.shape[0],epoch_acc)

        val_loss,val_acc,val_label_acc=evaluate(model,loss_fn,valloader)
        avg_acc,epoch_acc=get_avg_acc(epoch_acc,trainloader)
        avg_loss=epoch_loss/len(trainloader)
        print('Train')
        print(f"Loss: {avg_loss:.3f}\nAccuracy:\n\tEach_label_acc->{epoch_acc}\n\tAvg acc->{avg_acc:.3f}\n")
        print("-"*100)
        print("Validation")
        print(f"Loss: {val_loss:.3f}\nAccuracy: Each_label_acc->{val_label_acc}\n\tAvg acc->{val_acc:.3f}")

        if val_acc>max_acc:
            train["avg"].append(avg_acc)
            train["individual"].append(epoch_acc)
            train["loss"].append(avg_loss)

            val["avg"].append(val_acc)
            val["individual"].append(val_label_acc)
            val["loss"].append(val_loss)

            max_acc=val_acc
            torch.save(model.state_dict(),'multimodal.bin')
            print(f'Saving weights with avg_acc -> {max_acc:.3f}')

            # An improvement resets the patience counter.
            stop_value = 10
        elif val_acc<max_acc:
            stop_value -=1
            print(f'Patience - {stop_value} max_avg_acc -> {max_acc:.3f} and epoch_acc -> {val_acc:.3f}')

        if stop_value<=0:
            print("\nStopping")
            return train,val,best_model,max_acc
        print("-"*100)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # Previously returned the undefined name `max_loss`, raising NameError whenever
    # training ran to completion without early stopping.
    return train,val,best_model,max_acc

In [None]:
# Train the ResNet model for at most 140 epochs; train_model may stop earlier.
train_out,val_out,best_model,best_acc=train_model(multimodel, loss_fn, optim,train_loader,val_loader,num_epochs=140)

--------------------------------------------------
Epoch 1 of 140
Train
Loss: 0.027
Accuracy:
	Each_label_acc->{'phylum_label': 1.0, 'class_label': 1.0, 'species_label': 0.9995404411764706, 'form_label': 0.9958639705882353, 'sample_label': 1.0}
	Avg acc->0.999

----------------------------------------------------------------------------------------------------
Validation
Loss: 0.783
Accuracy: Each_label_acc->{'phylum_label': 0.9127789046653144, 'class_label': 0.8742393509127789, 'species_label': 0.6779355420329053, 'form_label': 0.7554654045526259, 'sample_label': 0.8275862068965518}
	Avg acc->0.810
Saving weights with avg_acc -> 0.810
----------------------------------------------------------------------------------------------------
--------------------------------------------------
Epoch 2 of 140
Train
Loss: 0.030
Accuracy:
	Each_label_acc->{'phylum_label': 0.9995404411764706, 'class_label': 1.0, 'species_label': 1.0, 'form_label': 0.9940257352941176, 'sample_label': 1.0}
	Avg acc->

In [None]:
# Rebuild the architecture and load the checkpoint written by train_model
# ('multimodal.bin'; the absolute path assumes the working directory is /content).
multimodel = MultiResNet(train_dataset.nclasses,train_dataset.labels)
multimodel.load_state_dict(torch.load('/content/multimodal.bin'))
multimodel.to(device)
preds, true = [],[]
# Replaces the sampler-based validation loader with a plain DataLoader (no sampler).
val_loader = DataLoader(val_dataset,batch_size=17)
multimodel.eval()
with torch.no_grad():
  for batch in val_loader:
    # NOTE(review): this rebinds the global `labels` (the list of target column names
    # defined earlier) to a dict of tensors, so earlier cells that use it cannot be
    # re-run after this one.
    img, labels = batch['image'], batch['labels']
    img = img.to(device)
    pred = multimodel(img)
    # Keep the raw per-label model outputs and the matching label dict for each batch.
    preds.append(pred)
    true.append(labels)


In [None]:
# One empty list per label column, for predicted and for ground-truth class ids.
label_columns = ('phylum_label', 'class_label', 'species_label', 'form_label', 'sample_label')
y_pred = {column: [] for column in label_columns}
y_true = {column: [] for column in label_columns}

In [None]:
# Flatten the per-batch results: for every label, the argmax over axis 1 of the model
# output gives the predicted class ids, appended alongside the ground-truth ids.
for i in range(len(preds)):
  for key in y_pred:
    y_pred[key].extend(preds[i][key].argmax(axis = 1).cpu().tolist())
    y_true[key].extend(true[i][key].tolist())

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
y = '_label'
for label in y_pred:
  # Encoders are keyed by the raw column name, i.e. without the '_label' suffix.
  label_le = label.replace(y,'')
  # classification_report builds its rows from the union of the ids in y_true and y_pred,
  # so the names must cover that same union. Using only y_true raised ValueError whenever
  # the model predicted a class that never occurs in the ground truth.
  present_ids = np.union1d(y_true[label],y_pred[label])
  classes = label_encoders[label_le].inverse_transform(present_ids)
  print(f"\nClassification Report for {label}:\n")
  print(classification_report(y_true[label],y_pred[label],target_names=classes))
  print('='*95)


Classification Report for phylum_label:

                   precision    recall  f1-score   support

        Amoebozoa       0.89      0.89      0.89        35
      Apicomplexa       0.94      1.00      0.97       245
         Nematoda       0.91      0.97      0.94       120
  Platyhelminthes       0.92      0.70      0.80        50
Sarcomastigophora       1.00      0.69      0.81        35

         accuracy                           0.93       485
        macro avg       0.93      0.85      0.88       485
     weighted avg       0.93      0.93      0.92       485


Classification Report for class_label:

               precision    recall  f1-score   support

 Aconoidasida       0.93      1.00      0.96       165
      Cestoda       0.97      0.72      0.83        50
  Chromadorea       0.81      0.85      0.83        95
  Conoidasida       0.91      0.96      0.93        80
      Enoplea       0.54      0.56      0.55        25
    Tubulinea       0.91      0.86      0.88        

## EfficientNet B0 Multi-label Classification

In [22]:
# Shape probe: run a random (2, 3, 224, 224) batch through an untrained EfficientNet-B0
# created with num_classes = 0 to see the size of the features it returns.
transformer = timm.models.efficientnet_b0(pretrained=False,num_classes = 0)
x = torch.randn(2,3,224,224)
with torch.no_grad():
  a= transformer(x)

In [23]:
# Display the shape of the probe output `a`.
a.size()

torch.Size([2, 1280])

In [34]:
class MultiEffNet(nn.Module):
  """EfficientNet-B0 backbone with one projection branch and one classifier head per label.

  Args:
    nclasses: mapping ``label -> number of classes`` used to size each head.
    labels: iterable of label names; one branch and one head are created per entry,
      in this order.

  ``forward`` returns a dict ``label -> logits``.
  """

  def __init__(self,nclasses,labels):
    super().__init__()

    # num_classes = 0: the code below treats the backbone output as 1280-d features.
    self.img_transformer = timm.models.efficientnet_b0(pretrained=True,num_classes = 0)

    # One 1280 -> 640 linear projection per label.
    self.dividers = nn.ModuleList()
    for _ in labels:
      self.dividers.append(nn.Linear(1280,640))

    # fc1 produces the concatenated query/key/value inputs for the attention layer.
    self.fc1 = nn.Linear(640,3*640,bias = False)
    self.mha = nn.MultiheadAttention(embed_dim = 640,num_heads = 8)

    # Heads are inserted one by one so their order follows `labels`.
    self.classifiers = nn.ModuleDict()
    for label in labels:
      head = nn.Sequential(
          nn.BatchNorm1d(640),
          nn.Linear(640,512),
          nn.ReLU(),
          nn.Linear(512,nclasses[label]),
      )
      self.classifiers[label] = head

  def forward(self,x):
    features = self.img_transformer(x)
    # Stack the per-label projections along a new leading axis (one slice per label).
    branches = torch.stack([divider(features) for divider in self.dividers],dim = 0)
    q, k, v = self.fc1(branches).split(640,dim = 2)
    attended,_ = self.mha(q,k,v)

    # The i-th attended slice is classified by the i-th head.
    return {
        label: self.classifiers[label](branch)
        for label,branch in zip(self.classifiers.keys(),attended)
    }


In [19]:
# Training preprocessing: resize to 224x224, convert to a tensor, then normalise each of
# the three channels with mean 0.5 and std 0.5. No augmentation is applied.
train_transformer = transforms.Compose([
        transforms.Resize((224,224)),
#         transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

# Validation preprocessing: the same steps as the training pipeline.
val_transformer = transforms.Compose([
        transforms.Resize((224,224)),
#         transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])


In [20]:
# Per-row sampling weights: the inverse frequency of each row's species label, so rows
# of rarer species are drawn more often by WeightedRandomSampler.
class_counts = train['species_label'].value_counts().to_dict()
weights = torch.tensor([1/class_counts[label] for label in train['species_label'].values])

train_dataset = MultiDataset(train,label_encoders,train_transformer)
# The number of draws per epoch is hard-coded to 2160.
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=32,
                                           sampler=torch.utils.data.WeightedRandomSampler(weights,num_samples=2160))

# Same inverse-frequency weighting, recomputed on the validation split.
class_counts = val['species_label'].value_counts().to_dict()
weights = torch.tensor([1/class_counts[label] for label in val['species_label'].values])

val_dataset = MultiDataset(val,label_encoders,val_transformer)
# NOTE(review): validation batches are also drawn by weighted random sampling (720 draws),
# so metrics computed through this loader are on a resampled set, not one pass over `val`.
val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=16,
                                         sampler=torch.utils.data.WeightedRandomSampler(weights,num_samples=720))

In [36]:
# NOTE(review): the device is hard-coded to CUDA, so `multimodel.to(device)` will fail
# on a machine without a GPU.
device = torch.device('cuda')
# Size the classifier heads from the per-label class counts of the training dataset.
multimodel = MultiEffNet(train_dataset.nclasses,train_dataset.labels)
multimodel.to(device)
# A single CrossEntropyLoss instance is shared by every label head.
loss_fn = nn.CrossEntropyLoss()
# Adam over every parameter that requires gradients.
optim = torch.optim.Adam(filter(lambda p: p.requires_grad, multimodel.parameters()),lr = 1e-4)

In [37]:
#@title Training Functions
def get_loss(out,labels,loss_fn,weights):
    """Return the weighted sum of the per-label losses.

    For every key in `labels`, `loss_fn(out[key], labels[key])` is scaled by
    `weights[key]`; the scaled terms are added up. Returns 0 when `labels` is empty.
    """
    return sum(loss_fn(out[name],labels[name])*weights[name] for name in labels)


def cal_accuracy(out,labels,batch_size,epoch_acc):
    """Add this batch's per-label accuracy to the running totals in `epoch_acc`.

    For each key of `out`, the argmax over axis 1 is compared with `labels[key]` and
    the fraction of matches (matches / batch_size) is accumulated under that key.
    `epoch_acc` is updated in place and also returned.
    """
    for name, scores in out.items():
        predicted = scores.detach().argmax(axis=1)
        batch_fraction = (predicted==labels[name]).sum().item()/batch_size
        epoch_acc[name] = epoch_acc.get(name, 0) + batch_fraction
    return epoch_acc


def get_avg_acc(epoch_acc,loader):
    """Turn accumulated per-label accuracies into per-batch means.

    Every value in `epoch_acc` is divided by `len(loader)` in place. Returns
    `(mean over labels, epoch_acc)`.
    """
    running_total = 0
    for name in epoch_acc:
        epoch_acc[name] = epoch_acc[name]/len(loader)
        running_total = running_total + epoch_acc[name]

    mean_over_labels = running_total/len(epoch_acc)
    return mean_over_labels,epoch_acc

def evaluate(model,loss_fn,loader):
    """Run `model` over `loader` without gradients and summarise loss and accuracy.

    Uses the module-level `device` and the helpers `get_loss`, `cal_accuracy` and
    `get_avg_acc`. The per-label loss weights come from `loader.dataset.weights`.

    Returns:
        (mean loss per batch, mean accuracy over labels, per-label accuracy dict)
    """
    model.eval()
    running_acc = {}
    running_loss = 0
    with torch.no_grad():
        for batch in loader:
            images = batch["image"].to(device)
            targets = batch["labels"]
            # Move every label tensor to the same device as the images.
            for name in targets:
                targets[name] = targets[name].to(device)

            predictions = model(images)
            batch_loss = get_loss(predictions,targets,loss_fn,loader.dataset.weights)
            running_loss += batch_loss.item()
            running_acc = cal_accuracy(predictions,targets,images.shape[0],running_acc)

    mean_acc,per_label_acc = get_avg_acc(running_acc,loader)
    mean_loss = running_loss/len(loader)
    return (mean_loss,mean_acc,per_label_acc)

def train_model(model, loss_fn, opt,trainloader,valloader, num_epochs=1):
    """Train `model`, checkpointing on the best validation accuracy, with early stopping.

    Each epoch runs one pass over `trainloader`, then evaluates on `valloader`.
    When the mean validation accuracy beats the best value so far, the epoch's
    metrics are appended to the returned histories and `model.state_dict()` is
    saved to 'multimodal.bin'. Training stops early once 10 epochs scoring below
    the best accuracy have accumulated since the last improvement.

    Relies on the module-level `device` and on `get_loss`, `cal_accuracy`,
    `get_avg_acc` and `evaluate`.

    Args:
        model: network returning a dict of per-label outputs.
        loss_fn: loss applied to each label's output.
        opt: optimizer stepping `model`'s parameters.
        trainloader, valloader: loaders whose dataset exposes `weights`.
        num_epochs: maximum number of epochs.

    Returns:
        (train_history, val_history, best_model, best_acc). The histories only hold
        entries for epochs that improved the best accuracy; `best_model` is always
        the empty string (the weights live in 'multimodal.bin').
    """
    stop_value = 10
    since = time.time()
    best_model=""
    # Start below any reachable accuracy so the first epoch always checkpoints.
    max_acc = -np.inf
    train={"loss":[],"avg":[],"individual":[]}
    val={"loss":[],"avg":[],"individual":[]}
    for epoch in range(num_epochs):
        print("-"*50)
        print('Epoch {} of {}'.format(epoch+1, num_epochs))
        epoch_loss=0
        epoch_acc={}
        model.train()
        for data in trainloader:
            opt.zero_grad()
            img=data["image"].to(device)
            labels=data["labels"]
            for key in labels:
                labels[key]=labels[key].to(device)

            out=model(img)
            loss=get_loss(out,labels,loss_fn,trainloader.dataset.weights)
            epoch_loss+=loss.detach().item()
            loss.backward()
            opt.step()
            epoch_acc=cal_accuracy(out,labels,img.shape[0],epoch_acc)

        val_loss,val_acc,val_label_acc=evaluate(model,loss_fn,valloader)
        avg_acc,epoch_acc=get_avg_acc(epoch_acc,trainloader)
        avg_loss=epoch_loss/len(trainloader)
        print('Train')
        print(f"Loss: {avg_loss:.3f}\nAccuracy:\n\tEach_label_acc->{epoch_acc}\n\tAvg acc->{avg_acc:.3f}\n")
        print("-"*100)
        print("Validation")
        print(f"Loss: {val_loss:.3f}\nAccuracy: Each_label_acc->{val_label_acc}\n\tAvg acc->{val_acc:.3f}")

        if val_acc>max_acc:
            train["avg"].append(avg_acc)
            train["individual"].append(epoch_acc)
            train["loss"].append(avg_loss)

            val["avg"].append(val_acc)
            val["individual"].append(val_label_acc)
            val["loss"].append(val_loss)

            max_acc=val_acc
            torch.save(model.state_dict(),'multimodal.bin')
            print(f'Saving weights with avg_acc -> {max_acc:.3f}')

            # An improvement resets the patience counter.
            stop_value = 10
        elif val_acc<max_acc:
            stop_value -=1
            print(f'Patience - {stop_value} max_avg_acc -> {max_acc:.3f} and epoch_acc -> {val_acc:.3f}')

        if stop_value<=0:
            print("\nStopping")
            return train,val,best_model,max_acc
        print("-"*100)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # Previously returned the undefined name `max_loss`, raising NameError whenever
    # training ran to completion without early stopping.
    return train,val,best_model,max_acc

In [38]:
# Train the EfficientNet model for at most 140 epochs; train_model may stop earlier.
train_out,val_out,best_model,best_acc=train_model(multimodel, loss_fn, optim,train_loader,val_loader,num_epochs=140)

--------------------------------------------------
Epoch 1 of 140
Train
Loss: 1.215
Accuracy:
	Each_label_acc->{'phylum_label': 0.7660845588235294, 'class_label': 0.7090992647058824, 'species_label': 0.5795036764705882, 'form_label': 0.6190257352941176, 'sample_label': 0.7426470588235294}
	Avg acc->0.683

----------------------------------------------------------------------------------------------------
Validation
Loss: 1.088
Accuracy: Each_label_acc->{'phylum_label': 0.7847222222222222, 'class_label': 0.7263888888888889, 'species_label': 0.5819444444444445, 'form_label': 0.6166666666666667, 'sample_label': 0.7680555555555556}
	Avg acc->0.696
Saving weights with avg_acc -> 0.696
----------------------------------------------------------------------------------------------------
--------------------------------------------------
Epoch 2 of 140
Train
Loss: 0.347
Accuracy:
	Each_label_acc->{'phylum_label': 0.9747242647058824, 'class_label': 0.9618566176470589, 'species_label': 0.90762867

In [39]:
# Per-label validation accuracy of the last recorded entry; train_model only records
# epochs that improved the best accuracy, so this is the best checkpoint's breakdown.
val_out['individual'][-1]

{'class_label': 0.8666666666666667,
 'form_label': 0.7375,
 'phylum_label': 0.9097222222222222,
 'sample_label': 0.8388888888888889,
 'species_label': 0.7194444444444444}

In [40]:
# Rebuild the architecture and load the checkpoint written by train_model
# ('multimodal.bin'; the absolute path assumes the working directory is /content).
multimodel = MultiEffNet(train_dataset.nclasses,train_dataset.labels)
multimodel.load_state_dict(torch.load('/content/multimodal.bin'))
multimodel.to(device)
preds, true = [],[]
# Replaces the sampler-based validation loader with a plain DataLoader (no sampler).
val_loader = DataLoader(val_dataset,batch_size=17)
multimodel.eval()
with torch.no_grad():
  for batch in val_loader:
    # NOTE(review): this rebinds the global `labels` (the list of target column names
    # defined earlier) to a dict of tensors, so earlier cells that use it cannot be
    # re-run after this one.
    img, labels = batch['image'], batch['labels']
    img = img.to(device)
    pred = multimodel(img)
    # Keep the raw per-label model outputs and the matching label dict for each batch.
    preds.append(pred)
    true.append(labels)


In [41]:
# One empty list per label column, for predicted and for ground-truth class ids.
label_columns = ('phylum_label', 'class_label', 'species_label', 'form_label', 'sample_label')
y_pred = {column: [] for column in label_columns}
y_true = {column: [] for column in label_columns}

In [42]:
# Flatten the per-batch results: for every label, the argmax over axis 1 of the model
# output gives the predicted class ids, appended alongside the ground-truth ids.
for i in range(len(preds)):
  for key in y_pred:
    y_pred[key].extend(preds[i][key].argmax(axis = 1).cpu().tolist())
    y_true[key].extend(true[i][key].tolist())

In [43]:
from sklearn.metrics import classification_report, confusion_matrix
y = '_label'
for label in y_pred:
  # Encoders are keyed by the raw column name, i.e. without the '_label' suffix.
  label_le = label.replace(y,'')
  # classification_report builds its rows from the union of the ids in y_true and y_pred,
  # so the names must cover that same union. Using only y_true raised ValueError whenever
  # the model predicted a class that never occurs in the ground truth.
  present_ids = np.union1d(y_true[label],y_pred[label])
  classes = label_encoders[label_le].inverse_transform(present_ids)
  print(f"\nClassification Report for {label}:\n")
  print(classification_report(y_true[label],y_pred[label],target_names=classes))
  print('='*95)


Classification Report for phylum_label:

                   precision    recall  f1-score   support

        Amoebozoa       0.83      0.86      0.85        35
      Apicomplexa       0.95      0.97      0.96       245
         Nematoda       0.87      0.95      0.91       120
  Platyhelminthes       0.78      0.58      0.67        50
Sarcomastigophora       0.81      0.71      0.76        35

         accuracy                           0.90       485
        macro avg       0.85      0.81      0.83       485
     weighted avg       0.90      0.90      0.89       485


Classification Report for class_label:

               precision    recall  f1-score   support

 Aconoidasida       0.98      0.96      0.97       165
      Cestoda       0.82      0.56      0.67        50
  Chromadorea       0.79      0.88      0.83        95
  Conoidasida       0.81      0.93      0.87        80
      Enoplea       0.68      0.60      0.64        25
    Tubulinea       0.81      0.86      0.83        