In [1]:
import torch
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
from data import XrayDataset
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score, BinaryRecall, BinaryPrecision, BinarySpecificity
import timm, tqdm

  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


In [3]:
df_val = pd.read_csv("//home/ubuntu/ductq/csvs/cls_csv/val_set.csv")
val_set = XrayDataset(df_val)
val_loader = DataLoader(val_set, batch_size=16, num_workers=8, shuffle=False)

In [4]:
model = timm.create_model("tf_efficientnet_b0", num_classes = 1)

state_dict = torch.load("/home/ubuntu/ductq/results/ckpt/v5/last.ckpt")["state_dict"]
state_dict = {k.replace('model.', ''):v for k, v in state_dict.items()}

model.load_state_dict(state_dict)
model.cuda().eval()

EfficientNet(
  (conv_stem): Conv2dSame(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
  (bn1): BatchNormAct2d(
    32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
    (drop): Identity()
    (act): SiLU(inplace=True)
  )
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn1): BatchNormAct2d(
          32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (aa): Identity()
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm

In [5]:
pred = torch.tensor([], device="cuda")
gt = torch.tensor([], device="cuda")

for batch in tqdm.tqdm(val_loader):
    x, y = batch["image"].cuda().float(), batch["label"].cuda().float()

    with torch.no_grad():
        output = model(x).float()

    predict = torch.sigmoid(output.squeeze(-1))
    
 
    pred = torch.cat((pred, predict),dim=0)
    gt = torch.cat((gt, y), dim=0)

100%|██████████| 94/94 [00:13<00:00,  7.04it/s]


In [6]:
for thr in np.arange(0.1, 0.9, 0.05):
    acc = BinaryAccuracy(threshold=thr).cuda()
    pre = BinaryPrecision(threshold=thr).cuda()
    sen = BinaryRecall(threshold=thr).cuda()
    f1 = BinaryF1Score(threshold=thr).cuda()
    spec = BinarySpecificity(threshold=thr).cuda()

    print(f"Results at {thr}: ")
    print(f"Accuracy: {acc(pred,gt)}")
    print(f"Precision: {pre(pred,gt)}")
    print(f"Recall: {sen(pred,gt)}")
    print(f"F1 Score {f1(pred,gt)}")
    print(f"Specificity {spec(pred,gt)}")
    print("-"*50)

Results at 0.1: 
Accuracy: 0.8973333239555359
Precision: 0.742397129535675
Recall: 0.9764705896377563
F1 Score 0.8434959053993225
Specificity 0.8660464882850647
--------------------------------------------------
Results at 0.15000000000000002: 
Accuracy: 0.921999990940094
Precision: 0.7972972989082336
Recall: 0.9717646837234497
F1 Score 0.8759278655052185
Specificity 0.9023255705833435
--------------------------------------------------
Results at 0.20000000000000004: 
Accuracy: 0.9286666512489319
Precision: 0.8205645084381104
Recall: 0.9576470851898193
F1 Score 0.8838219046592712
Specificity 0.9172093272209167
--------------------------------------------------
Results at 0.25000000000000006: 
Accuracy: 0.9386666417121887
Precision: 0.8475991487503052
Recall: 0.955294132232666
F1 Score 0.8982300758361816
Specificity 0.9320930242538452
--------------------------------------------------
Results at 0.30000000000000004: 
Accuracy: 0.9433333277702332
Precision: 0.8648068904876709
Recall: 0.9