In [None]:
import os
import cv2
import csv
import numpy as np
from scipy.stats import wilcoxon

import torch
import torchvision.transforms as transforms
from tqdm.notebook import tqdm

from models import SwinTransformer

### Set the paths to the pretrained weights and the test data

In [None]:
# Checkpoint file with the pretrained model weights.
pretrained_weight_path = "DWCC.pt"
# Directory whose entries are each classified as one folder of images.
test_path = "/covid/data/test/"

## Load model

In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Swin Transformer with a single output value per 224x224, 3-channel input image.
model = SwinTransformer(img_size=224,
                        patch_size=4,
                        in_chans=3,
                        num_classes=1,
                        embed_dim=96,
                        depths=[2, 2, 6, 2],
                        num_heads=[3, 6, 12, 24],
                        window_size=7,
                        mlp_ratio=4.0,
                        qkv_bias=True,
                        qk_scale=None,
                        drop_rate=0.0,
                        drop_path_rate=0.2,
                        ape=False,
                        patch_norm=True,
                        use_checkpoint=False,
                        device=device)

# map_location remaps tensors that were saved from a GPU onto the selected
# device, so the checkpoint also loads on CPU-only machines.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
state_dict = torch.load(pretrained_weight_path, map_location=device)

# strict=False tolerates key mismatches; report them so that a wrong or partial
# checkpoint does not go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
missing_keys = getattr(load_result, 'missing_keys', [])
unexpected_keys = getattr(load_result, 'unexpected_keys', [])
if missing_keys:
    print(f'Keys missing from the checkpoint: {missing_keys}')
if unexpected_keys:
    print(f'Unexpected keys in the checkpoint: {unexpected_keys}')

model.to(device)

In [None]:
# Normalisation statistics: one mean (m) and one standard deviation (s),
# reused for all three channels.
m = 0.39221061670618984
s = 0.11469786773730418

# Preprocessing pipeline: array -> PIL image -> 224x224 -> tensor -> normalised.
t = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((m,) * 3, (s,) * 3),
    ]
)

In [None]:
def inference(path, st=3, ed=7, model=model, t=t):
    """Classify one folder of numerically named image files.

    The files in ``path`` are ordered by the integer that precedes the first
    '.' in their names. The slices between the ``st``-tenth and the
    ``ed``-tenth of that ordered sequence (inclusive) are run through
    ``model``, and the collected outputs are passed to ``wilcoxon_rank_test``.

    Args:
        path: Directory containing the image files.
        st: Start of the slice window, in tenths of the sequence length.
        ed: End of the slice window (inclusive), in tenths of the sequence
            length.
        model: Network called on each preprocessed image; its output must be
            a single-element tensor.
        t: Preprocessing transform applied to each RGB image array.

    Returns:
        1 if the p-value returned by ``wilcoxon_rank_test`` is below 0.05,
        otherwise 0.

    Raises:
        ValueError: If a file name does not start with an integer.
        IOError: If an image file cannot be read.
    """
    model.eval()

    img_names = sorted(os.listdir(path), key=lambda name: int(name.split('.')[0]))
    ct_len = len(img_names)

    start_idx = int(round(ct_len / 10 * st, 0))
    # "+ 1" makes the window inclusive; clamp it so that very short folders
    # (or ed=10) cannot index past the last slice.
    end_idx = min(int(round(ct_len / 10 * ed, 0)) + 1, ct_len)

    pop = []
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for i in range(start_idx, end_idx):
            img_path = os.path.join(path, img_names[i])
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise IOError(f'Could not read image: {img_path}')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            batch = t(img).to(device).unsqueeze(0)
            pop.append(model(batch).item())

    p_value = wilcoxon_rank_test(pop)
    return 1 if p_value < 0.05 else 0

In [None]:
def wilcoxon_rank_test(pop):
    """Return the one-sided Wilcoxon signed-rank p-value for the model outputs.

    Only outputs lying within ``2 * sqrt(0.2)`` of +1 or of -1 are kept; all
    other values are discarded. The kept values are tested with
    ``scipy.stats.wilcoxon(..., alternative='greater')``.

    Args:
        pop: Sequence of scalar model outputs.

    Returns:
        The p-value of the test, or 1.0 when no output falls inside either band.
    """
    scores = np.array(pop)
    half_width = np.sqrt(0.2) * 2

    def _within_band(center):
        # Values inside the closed interval [center - half_width, center + half_width].
        lower_ok = scores >= center - half_width
        upper_ok = scores <= center + half_width
        return scores[lower_ok & upper_ok]

    # Keep the positive band first, then the negative band.
    selected = np.concatenate((_within_band(1), _within_band(-1)))
    if len(selected) == 0:
        return 1.0

    _, p_value = wilcoxon(selected, alternative='greater')
    return p_value

In [None]:
# Folder names, split by the prediction returned by inference().
covid = []
non_covid = []

# Keep only sub-directories: inference() lists each path as a folder, so a
# stray file in test_path would otherwise abort the whole run. Sorting makes
# the processing and output order reproducible across machines.
test_folder = sorted(
    entry for entry in os.listdir(test_path)
    if os.path.isdir(os.path.join(test_path, entry))
)
for folder in tqdm(test_folder):
    path = os.path.join(test_path, folder)
    pred = inference(path)
    if pred == 1:
        covid.append(folder)
    else:
        non_covid.append(folder)

In [None]:
# Create the output directory if needed so that open() cannot fail on a fresh checkout.
os.makedirs('result', exist_ok=True)

# Each file holds a single CSV row listing the folder names of one class.
# newline='' is required by the csv module; without it, extra '\r' characters
# are written on platforms that translate line endings.
for csv_path, folders in (('result/covid.csv', covid),
                          ('result/non-covid.csv', non_covid)):
    with open(csv_path, 'w', encoding='UTF8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(folders)