In [12]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import random
import timm
import torch
from albumentations import (
    Compose,
    Normalize,
    ShiftScaleRotate,
    RandomBrightnessContrast,
    MotionBlur,
    CLAHE,
    HorizontalFlip
)
from copy import deepcopy
from torch.utils.data import Dataset
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix

In [13]:
# from google.colab import drive
# drive.mount('/content/drive')

In [14]:
# %cd drive/MyDrive/'BIOMEDIN220-F2022'/

In [15]:
# Checkpoint-name prefix; the inference cell appends "<class_id>_weighted.pth" to it.
MODEL_NAME = 'multilabel_efnb4_v1_cls'

# Root folder of the resized dataset; every other location hangs off it.
dataset_path = "vinbigdata-chest-xray-resized-png-256x256"
model_path = f"{dataset_path}/save_models"
class_weights_path = f"{dataset_path}/class_weights.npy"

# Label CSVs and image folders for the two splits.
train_csv_path, test_csv_path = (
    os.path.join(dataset_path, csv_name)
    for csv_name in ('vindrcxr_train.csv', 'vindrcxr_test.csv')
)
train_image_path, test_image_path = (
    os.path.join(dataset_path, split_dir) for split_dir in ('train', 'test')
)

# Joining with '' leaves a trailing separator, so later cells can append a
# file name to save_path by plain string concatenation.
save_path = os.path.join(model_path, '')

print(train_image_path, test_image_path, sep="\n")

vinbigdata-chest-xray-resized-png-256x256/train
vinbigdata-chest-xray-resized-png-256x256/test


In [16]:
# Sanity check: list the dataset folder to confirm the CSVs, image folders,
# class weights and save_models directory referenced by the config cell exist.
!ls vinbigdata-chest-xray-resized-png-256x256

class_weights.npy  train.csv
save_models	   train_meta.csv
test		   vinbigdata-chest-xray-resized-png-256x256.zip
test.csv	   vindrcxr_test.csv
train		   vindrcxr_train.csv


In [17]:
### Code from https://github.com/Scu-sen/VinBigData-Chest-X-ray-Abnormalities-Detection

class Dataset(Dataset):
    """Multi-label image dataset backed by a DataFrame and a folder of PNG files.

    NOTE: the class keeps the name ``Dataset`` (shadowing the
    ``torch.utils.data.Dataset`` base it derives from) because other cells
    instantiate it under that name.

    Parameters:
        df (pd.DataFrame): One row per image. Must contain an ``image_id``
            column and the label columns named '0' ... '14'.
        image_path (str): Directory that holds ``<image_id>.png`` files.
        transform (callable, optional): Called as ``transform(image=img)``;
            the ``'image'`` entry of its result replaces the image.
    """

    # Names of the 15 label columns read from each row.
    LABEL_COLUMNS = [str(i) for i in range(15)]

    def __init__(self, df, image_path, transform=None):
        self.df = df
        self.image_path = image_path
        self.transform = transform

    def __len__(self):
        """Number of rows (images) in the backing frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for the row at *position* ``idx``.

        Returns:
            img (torch.FloatTensor): Channel-first (C, H, W) RGB image.
            labels (torch.FloatTensor): The 15 label values of the row.

        Raises:
            FileNotFoundError: If the PNG cannot be read from disk.
        """
        # Positional lookup: samplers hand out 0..len-1, which only coincides
        # with index *labels* when the frame has a default RangeIndex.
        row = self.df.iloc[idx]
        labels = torch.from_numpy(
            row[self.LABEL_COLUMNS].values.astype(float)
        ).float()

        file_name = self.image_path + '/' + str(row['image_id']) + '.png'
        img = cv2.imread(file_name)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {file_name}")

        # OpenCV decodes to BGR; convert to RGB channel order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform:
            img = self.transform(image=img)['image']
        # HWC -> CHW, as float32.
        img = torch.from_numpy(img.transpose((2, 0, 1))).float()

        return img, labels

In [18]:
# Batch size handed to the DataLoader in the inference cell.
bs = 2
# Learning rate and epoch count; no cell visible in this notebook section reads them.
lr = 0.001
N_EPOCHS = 10

In [19]:
def seed_everything(seed):
    """
    Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), and configures cuDNN for deterministic behaviour.

    Parameters:
        seed (int): The seed value.
    """
    random.seed(seed)
    # Only affects hashing in processes started after this point; the running
    # interpreter's hash seed is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    # benchmark=True lets cuDNN auto-tune and pick different algorithms from
    # run to run, which defeats deterministic=True; it must be off here.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [20]:
def eval_model(model, data_loader):
    """
    Run inference over a data loader and collect sigmoid probabilities.

    Parameters:
        model (torch.nn.Module): The model to evaluate. It is switched to
            eval mode; inputs are moved to the device of its first parameter.
        data_loader (torch.utils.data.DataLoader): Yields ``(imgs, labels)``
            batches.

    Return:
        preds_list (np.ndarray): Sigmoid of the model outputs for all batches,
            concatenated along axis 0.
        targets_list (list of np.ndarray): One integer array of rounded labels
            per batch (NOT concatenated).

    Raises:
        ValueError: If ``data_loader`` yields no batches (nothing to concatenate).
    """
    model.eval()

    # Follow the model's own device instead of assuming CUDA, so the function
    # behaves the same for a GPU model and also works for a CPU model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    preds_list, targets_list = [], []

    with torch.no_grad():
        tk = tqdm(data_loader, total=len(data_loader), position=0, leave=True)

        for imgs, labels in tk:
            if device is not None:
                imgs = imgs.to(device)
            output = model(imgs)

            preds = torch.sigmoid(output).detach().cpu().numpy()
            # Labels are only converted to NumPy, so they never need to visit the GPU.
            labels = labels.detach().cpu().numpy()

            preds_list.append(preds)
            targets_list.append(labels.round().astype(int))
    preds_list = np.concatenate(preds_list, axis=0)
    return preds_list, targets_list

In [21]:
# One checkpoint per class: checkpoint k contributes only column k of `preds`.
n_classes = 15

test = pd.read_csv(test_csv_path)
test_transform = Compose([
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0)
])

# The dataset and loader do not depend on the class being evaluated, so they
# are built once instead of once per checkpoint.
valset = Dataset(
    test,
    image_path=test_image_path,
    transform=test_transform
)
val_loader = torch.utils.data.DataLoader(
    valset, batch_size=bs, shuffle=False, num_workers=1
)

# Size the output from the data rather than a hard-coded row count.
preds = np.zeros((len(test), n_classes))
for class_id in range(n_classes):
    checkpoint_file = os.path.join(model_path, f'{MODEL_NAME}{class_id}_weighted.pth')

    # pretrained=False: the checkpoint below overwrites every weight (strict
    # load), so fetching the ImageNet weights first would be wasted work.
    model = timm.create_model('tf_efficientnet_b4_ns', pretrained=False, num_classes=n_classes).cuda()
    model.load_state_dict(torch.load(checkpoint_file)['weight'])

    preds_list, targets_list = eval_model(model, val_loader)
    preds[:, class_id] = preds_list[:, class_id]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

In [26]:
# One column per class (0..14), plus the image ids taken from the test frame.
df_preds = pd.DataFrame(data=preds, columns=np.arange(15))
df_preds = df_preds.assign(image_id=test['image_id'])

In [27]:
# Persist the predictions next to the checkpoints (save_path already ends with a separator).
df_preds.to_csv(path_or_buf=save_path + 'multilabel_efnb4_weighted_preds.csv', index=False)

In [31]:
def get_per_class_metrics(metric_fn, targets_list, preds_list, class_weights=None):
    """
    Evaluate a metric column by column and aggregate it into one number.

    Parameters:
        metric_fn (callable): Called as ``metric_fn(class_targets, class_preds)``
            for each column.
        targets_list (array-like): 2-D targets, one column per class.
        preds_list (array-like): 2-D predictions, same layout as ``targets_list``.
        class_weights (array-like, optional): Per-class weights for the overall
            value. If None, the overall value is the plain mean.

    Return:
        overall_metric (float): (Weighted) mean over the classes whose metric
            is defined; NaN if no class has a defined metric.
        per_class_metrics (np.ndarray): One value per class; NaN for classes
            whose targets contain a single distinct value.
    """
    targets_by_class = np.asarray(targets_list).T
    preds_by_class = np.asarray(preds_list).T

    # A class with only one distinct target value is reported as NaN instead
    # of calling the metric on it.
    per_class_metrics = np.array([
        metric_fn(class_targets, class_preds) if len(set(class_targets)) > 1 else np.nan
        for class_targets, class_preds in zip(targets_by_class, preds_by_class)
    ])

    valid = ~np.isnan(per_class_metrics)
    if not valid.any():
        return np.nan, per_class_metrics

    if class_weights is None:
        overall_metric = per_class_metrics[valid].mean()
    else:
        weights = np.broadcast_to(
            np.asarray(class_weights, dtype=float), per_class_metrics.shape
        )
        # Drop the weights of undefined classes from the denominator as well;
        # otherwise skipped classes still dilute the weighted mean.
        overall_metric = (
            np.nansum(weights[valid] * per_class_metrics[valid])
            / np.nansum(weights[valid])
        )
    return overall_metric, per_class_metrics

### Load Predictions

In [32]:
# Reload the per-class weights and the predictions CSV written by the export cell.
class_weights = np.load(file=class_weights_path)
df_preds = pd.read_csv(filepath_or_buffer=save_path + 'multilabel_efnb4_weighted_preds.csv')

### Compute Metrics

In [37]:
# Label columns of the test frame are the strings '0'..'14'.
class_id_list = [str(class_idx) for class_idx in range(15)]
# Ground truth comes straight from the test frame (alternative: np.concatenate(targets_list, axis=0)).
targets = test[class_id_list].to_numpy()

# AUC works on the raw probabilities and is aggregated without class weights.
overall_auc, aucs = get_per_class_metrics(roc_auc_score, targets, preds)

# Binarise with a strict 0.95 cut-off for the threshold-dependent metrics.
thresholded_preds = preds >= 0.95

overall_acc, accs = get_per_class_metrics(
    accuracy_score, targets, thresholded_preds, class_weights
)
overall_prec, precs = get_per_class_metrics(
    precision_score, targets, thresholded_preds, class_weights
)
overall_recall, recalls = get_per_class_metrics(
    recall_score, targets, thresholded_preds, class_weights
)
overall_f1, f1s = get_per_class_metrics(
    f1_score, targets, thresholded_preds, class_weights
)

print(
    f"Overall AUC: {overall_auc}",
    f"Per-Class AUCs: {aucs}",
    f"Overall Accuracy: {overall_acc}",
    f"Per-Class Accuracies: {accs}",
    f"Overall Precision: {overall_prec}",
    f"Per-Class Precisions: {precs}",
    f"Overall Recall: {overall_recall}",
    f"Per-Class Recalls: {recalls}",
    f"Overall F1 Score: {overall_f1}",
    f"Per-Class F1 Scores: {f1s}",
    sep="\n",
)

Overall AUC: 0.9041782677715704
Per-Class AUCs: [0.87858241 0.86805478 0.85228817 0.94435605 0.9362589  0.88319474
 0.92946389 0.87950307 0.88059732 0.86079791 0.98039161 0.88051768
 0.97067591 0.8787222  0.93926939]
Overall Accuracy: 0.9616077780389338
Per-Class Accuracies: [0.92133333 0.96166667 0.93433333 0.922      0.96333333 0.92733333
 0.96466667 0.93466667 0.94166667 0.969      0.97733333 0.944
 0.987      0.933      0.79233333]
Overall Precision: 0.3476865550866833
Per-Class Precisions: [0.41111111 0.23636364 0.41176471 0.79527559 0.42553191 0.51764706
 0.26923077 0.14556962 0.50980392 0.66666667 0.66929134 0.52631579
 0.23076923 0.63793103 0.96005155]
Overall Recall: 0.308250088244502
Per-Class Recalls: [0.16818182 0.15116279 0.03608247 0.32686084 0.41666667 0.19909502
 0.48275862 0.27380952 0.14772727 0.0212766  0.76576577 0.0591716
 0.5        0.17050691 0.72647489]
Overall F1 Score: 0.27055337062964707
Per-Class F1 Scores: [0.23870968 0.18439716 0.06635071 0.46330275 0.4210

In [34]:
# zero_division=0 reports the same 0.0 values as the default, but without the
# _warn_prf warnings raised for averages that have an empty denominator.
print(classification_report(targets, thresholded_preds, zero_division=0))

              precision    recall  f1-score   support

           0       0.27      0.73      0.40       220
           1       0.16      0.67      0.26        86
           2       0.22      0.66      0.33       194
           3       0.52      0.80      0.63       309
           4       0.24      0.83      0.38        96
           5       0.26      0.78      0.39       221
           6       0.09      0.88      0.16        58
           7       0.09      0.95      0.16        84
           8       0.23      0.73      0.35       176
           9       0.09      0.84      0.17        94
          10       0.27      0.95      0.42       111
          11       0.21      0.75      0.33       169
          12       0.05      0.89      0.09        18
          13       0.21      0.86      0.34       217
          14       0.92      0.93      0.92      2051

   micro avg       0.35      0.86      0.49      4104
   macro avg       0.26      0.82      0.35      4104
weighted avg       0.59   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [89]:
# Same report as the 0.95-threshold cell, but binarising with np.round
# (i.e. a cut-off at 0.5).
class_id_list = list(map(str, range(15)))
# Ground truth from the test frame (alternative: np.concatenate(targets_list, axis=0)).
targets = test[class_id_list].to_numpy()

# Threshold-free metric first, aggregated without class weights.
overall_auc, aucs = get_per_class_metrics(roc_auc_score, targets, preds)

thresholded_preds = np.round(preds)

# Threshold-dependent metrics, aggregated with the class weights.
overall_acc, accs = get_per_class_metrics(
    accuracy_score, targets, thresholded_preds, class_weights
)
overall_prec, precs = get_per_class_metrics(
    precision_score, targets, thresholded_preds, class_weights
)
overall_recall, recalls = get_per_class_metrics(
    recall_score, targets, thresholded_preds, class_weights
)
overall_f1, f1s = get_per_class_metrics(
    f1_score, targets, thresholded_preds, class_weights
)

print(
    f"Overall AUC: {overall_auc}",
    f"Per-Class AUCs: {aucs}",
    f"Overall Accuracy: {overall_acc}",
    f"Per-Class Accuracies: {accs}",
    f"Overall Precision: {overall_prec}",
    f"Per-Class Precisions: {precs}",
    f"Overall Recall: {overall_recall}",
    f"Per-Class Recalls: {recalls}",
    f"Overall F1 Score: {overall_f1}",
    f"Per-Class F1 Scores: {f1s}",
    sep="\n",
)

Overall AUC: 0.8678084607157158
Per-Class AUCs: [0.86733322 0.85309891 0.81673108 0.92776112 0.85983772 0.88237899
 0.90179681 0.8298509  0.78127616 0.84069643 0.94773902 0.86765502
 0.91431925 0.83873617 0.8879161 ]
Overall Accuracy: 0.9321333333333333
Per-Class Accuracies: [0.908      0.966      0.93233333 0.92066667 0.95       0.92933333
 0.94166667 0.93033333 0.94066667 0.95766667 0.92833333 0.93066667
 0.994      0.90566667 0.84666667]
Overall Precision: 0.3686699468401128
Per-Class Precisions: [0.37155963 0.13636364 0.2        0.68020305 0.26315789 0.55555556
 0.20304569 0.13872832 0.44444444 0.11627907 0.31428571 0.38596491
 0.5        0.37593985 0.84452144]
Overall Recall: 0.3395191954493697
Per-Class Recalls: [0.36818182 0.03488372 0.01546392 0.43365696 0.3125     0.20361991
 0.68965517 0.28571429 0.04545455 0.05319149 0.79279279 0.39053254
 0.05555556 0.46082949 0.95075573]
Overall F1 Score: 0.2980265738427365
Per-Class F1 Scores: [0.36986301 0.05555556 0.02870813 0.52964427 

In [90]:
# Unweighted AUC / accuracy. The per-class loop and NaN handling live in
# get_per_class_metrics, so call it instead of repeating the comprehension here.
preds_list_output = preds.T
targets_list_output = targets.T
thresholded_preds_list = np.round(preds_list_output)

# No class_weights argument -> the overall value is the plain mean over classes.
overall_auc, aucs = get_per_class_metrics(roc_auc_score, targets, preds)
# The helper expects one column per class, hence the transpose back.
overall_acc, accs = get_per_class_metrics(accuracy_score, targets, thresholded_preds_list.T)

print(f"Overall AUC: {overall_auc}")
print(f"Overall Accuracy: {overall_acc}")
print(f"Per-Class AUCs: {aucs}")
print(f"Per-Class Accuracies: {accs}")

Overall AUC: 0.8678084607157158
Overall Accuracy: 0.9321333333333333
Per-Class AUCs: [0.86733322 0.85309891 0.81673108 0.92776112 0.85983772 0.88237899
 0.90179681 0.8298509  0.78127616 0.84069643 0.94773902 0.86765502
 0.91431925 0.83873617 0.8879161 ]
Per-Class Accuracies: [0.908      0.966      0.93233333 0.92066667 0.95       0.92933333
 0.94166667 0.93033333 0.94066667 0.95766667 0.92833333 0.93066667
 0.994      0.90566667 0.84666667]


In [91]:
# zero_division=0 reports the same 0.0 values as the default, but without the
# _warn_prf warnings raised for averages that have an empty denominator.
print(classification_report(targets, thresholded_preds, zero_division=0))

              precision    recall  f1-score   support

           0       0.37      0.37      0.37       220
           1       0.14      0.03      0.06        86
           2       0.20      0.02      0.03       194
           3       0.68      0.43      0.53       309
           4       0.26      0.31      0.29        96
           5       0.56      0.20      0.30       221
           6       0.20      0.69      0.31        58
           7       0.14      0.29      0.19        84
           8       0.44      0.05      0.08       176
           9       0.12      0.05      0.07        94
          10       0.31      0.79      0.45       111
          11       0.39      0.39      0.39       169
          12       0.50      0.06      0.10        18
          13       0.38      0.46      0.41       217
          14       0.84      0.95      0.89      2051

   micro avg       0.63      0.63      0.63      4104
   macro avg       0.37      0.34      0.30      4104
weighted avg       0.62   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [92]:
# Manual precision/recall for class 0 as a cross-check of the reports above.
class_0_preds = (preds[:, 0] >= 0.5).astype(int)
class_0_targets = targets[:, 0]

# labels=[0, 1] pins the 2x2 layout even if one of the two values never
# occurs; without it the matrix can shrink and the unpacking below would fail.
cnf_matrix = confusion_matrix(class_0_targets, class_0_preds, labels=[0, 1])
print("Class 0 Confusion Matrix", cnf_matrix)

# Rows are true labels, columns are predicted labels.
tn, fp, fn, tp = cnf_matrix.ravel()

# Report NaN instead of dividing by zero when there are no predicted
# (precision) or no actual (recall) positives.
precision = tp / (tp + fp) if (tp + fp) > 0 else float('nan')
recall = tp / (tp + fn) if (tp + fn) > 0 else float('nan')
print(f"Class 0 Precision: {precision}")
print(f"Class 0 Recall: {recall}")

Class 0 Confusion Matrix [[2643  137]
 [ 139   81]]
Class 0 Precision: 0.37155963302752293
Class 0 Recall: 0.36818181818181817
