In [1]:
import torchxrayvision as xrv
import skimage, torch, torchvision
from pathlib import Path

In [2]:
# Pretrained torchxrayvision weight identifiers. The names embed the input
# resolution ("res224" / "res512").
xrv_224, xrv_224_chex = "densenet121-res224-all", "densenet121-res224-chex"
xrv_512 = "resnet50-res512-all"

# One model instance per weight set: two DenseNets and one ResNet.
model_xrv_224, model_xrv_224_chex = (
    xrv.models.DenseNet(weights=weights_name)
    for weights_name in (xrv_224, xrv_224_chex)
)
model_xrv_512 = xrv.models.ResNet(weights=xrv_512)

In [16]:
def process_image(model,image_path):
    """Load an image file and convert it into a batched tensor for an xrv model.

    Args:
        model: Weights identifier *string* (e.g. "densenet121-res224-all"),
            not a model object. If it contains "224" the image is resized to
            224; any other value resizes to 512.
        image_path: Location of the image file, read with skimage.io.imread.

    Returns:
        torch.Tensor built from the centre-cropped, resized image with a
        leading batch axis added (presumably (1, 1, res, res) -- confirm
        against XRayResizer's output shape).
    """
    img = skimage.io.imread(image_path)
    img = xrv.datasets.normalize(img, 255) # convert 8-bit image to [-1024, 1024] range

    # Colour / alpha images decode as (H, W, C). Collapse them to one channel
    # so the channel axis added below is the only one.
    if img.ndim == 3:
        if img.shape[2] in (2, 4):
            # Last channel is alpha (LA / RGBA); it carries no intensity data.
            img = img[..., :-1]
        img = img.mean(2)
    img = img[None,...]  # add the single channel axis -> (1, H, W)

    # The weights name encodes the expected input resolution.
    resolution = 224 if "224" in model else 512
    transform = torchvision.transforms.Compose([
        xrv.datasets.XRayCenterCrop(),
        xrv.datasets.XRayResizer(resolution),
    ])
    img = transform(img)
    img = torch.from_numpy(img).unsqueeze(0)  # add batch axis

    return img

In [20]:
test_png_dset_path = '/vol/biodata/data/chest_xray/VinDr-CXR/1.0.0_png_512/raw/test/688ecdb1a4e994d42b5a50a8c4a9736f.png'

# Both DenseNet variants consume the same 224-resolution input.
img = process_image(xrv_224,test_png_dset_path)

outputs_224 = model_xrv_224(img)
scores_224 = dict(zip(model_xrv_224.pathologies, outputs_224[0].detach().numpy()))
# print keys where value is greater than 0.5
print([name for name in scores_224 if scores_224[name] > 0.5])

outputs_224_chex = model_xrv_224_chex(img)
scores_224_chex = dict(zip(model_xrv_224_chex.pathologies, outputs_224_chex[0].detach().numpy()))
print([name for name in scores_224_chex if scores_224_chex[name] > 0.5])

# The ResNet needs its own 512-resolution preprocessing of the same file.
img = process_image(xrv_512,test_png_dset_path)
outputs_512 = model_xrv_512(img)
scores_512 = dict(zip(model_xrv_512.pathologies, outputs_512[0].detach().numpy()))
print([name for name in scores_512 if scores_512[name] > 0.5])


['Infiltration', 'Cardiomegaly']
['Cardiomegaly', 'Enlarged Cardiomediastinum']
['Lung Opacity']


In [28]:
# Text file that is read line by line; each stripped line is used as an image id.
vindr_test_test_split_path = (
    '/vol/biomedic3/bglocker/ugproj2324/nns20/datasets/VinDr-CXR/test_set_three_splits/VinDr_test_test_split.txt'
)
# Directory holding the test PNGs, looked up as "<img_id>.png".
vindr_dir_test_path = Path(
    '/vol/biodata/data/chest_xray/VinDr-CXR/1.0.0_png_512/raw/test/'
)

In [40]:
# Per-model results: image id -> list of pathology names scored above 0.5.
img_outputs_224 = {}
img_outputs_224_chex = {}
img_outputs_512 = {}


def _positive_findings(model, img, threshold=0.5):
    """Return the pathology names whose output for `img` exceeds `threshold`.

    Args:
        model: Callable model exposing a `pathologies` sequence aligned with
            its output vector.
        img: Batched input tensor; only the first item of the batch is read.
        threshold: Score a pathology must exceed to be reported.

    Returns:
        List of pathology names. Names are de-duplicated through a dict, so a
        repeated name keeps its last score.
    """
    # Inference only: skip autograd bookkeeping for every image in the split.
    with torch.no_grad():
        scores = model(img)[0].detach().numpy()
    return [k for k, v in dict(zip(model.pathologies, scores)).items() if v > threshold]


# `with` guarantees the split file is closed even if an image fails to load.
with open(vindr_test_test_split_path) as split_file:
    for line in split_file:
        img_id = line.strip()
        if not img_id:
            # A blank (e.g. trailing) line would otherwise resolve to ".png".
            continue
        img_path = vindr_dir_test_path / f'{img_id}.png'

        # Both DenseNet variants share the 224-resolution input.
        img = process_image(xrv_224,img_path)
        outputs_224 = _positive_findings(model_xrv_224, img)
        img_outputs_224[img_id] = outputs_224

        outputs_224_chex = _positive_findings(model_xrv_224_chex, img)
        img_outputs_224_chex[img_id] = outputs_224_chex

        # The ResNet takes the same file preprocessed at 512 resolution.
        img = process_image(xrv_512,img_path)
        outputs_512 = _positive_findings(model_xrv_512, img)
        img_outputs_512[img_id] = outputs_512

 
    

In [41]:
# write to three files, one per model; each line is "<img_id>,<finding>,<finding>,..."
out_dir = Path('/vol/biomedic3/bglocker/ugproj2324/nns20/torchxrayvision/evaluations/VinDr_evaluation_results')
out_dir.mkdir(parents=True, exist_ok=True)
out_224_path = out_dir / 'xrv_224.txt'
out_224_chex_path = out_dir / 'xrv_224_chex.txt'
out_512_path = out_dir / 'xrv_512.txt'

# Pair every destination file with the result dict that belongs in it.
for out_path, img_outputs in (
    (out_224_path, img_outputs_224),
    (out_224_chex_path, img_outputs_224_chex),
    (out_512_path, img_outputs_512),
):
    with open(out_path,'w') as f:
        # An image with no findings yields "<img_id>," (empty second field).
        f.writelines(
            f'{img_id},{",".join(outputs)}\n'
            for img_id, outputs in img_outputs.items()
        )

In [24]:
out_dir = Path('/vol/biomedic3/bglocker/ugproj2324/nns20/torchxrayvision/evaluations/VinDr_evaluation_results')
out_dir.mkdir(exist_ok=True,parents=True)
out_224_path = out_dir / 'xrv_224.txt'
out_224_chex_path = out_dir / 'xrv_224_chex.txt'
out_512_path = out_dir / 'xrv_512.txt'
output_files = [out_224_path,out_224_chex_path,out_512_path]

# A line of the form "<img_id>," (exactly two comma-separated fields, the second
# one empty) means no pathology was reported; append "No finding" to it.
# Every other line is written back untouched, so re-running this is a no-op.
for out_path in output_files:
    # Read the file once and close it before reopening it for writing.
    with open(out_path,'r') as f:
        lines = f.readlines()

    fixed_lines = []
    for line in lines:
        stripped = line.strip()
        fields = stripped.split(",")
        if len(fields) == 2 and len(fields[1]) == 0:
            fixed_lines.append(f'{stripped}No finding\n')
        else:
            fixed_lines.append(line)

    with open(out_path,'w') as f:
        f.writelines(fixed_lines)
