# CXR-ML-GZSL

## Overview

The goal of this notebook is to reproduce the findings of the paper "Multi-Label Generalized Zero Shot Learning for the Classification of Disease in Chest Radiographs" with the help of an LLM. The paper provides code, but this notebook reuses only the paper's split of the dataset into training, validation, and test sets.

* Paper: https://arxiv.org/abs/2107.06563
* Dataset split: https://github.com/nyuad-cai/CXR-ML-GZSL/tree/master/dataset_splits

The paper uses a dataset that was initially released as `ChestX-ray8` and later renamed `ChestX-ray14` when it was expanded from eight to fourteen distinct disease labels. The dataset contains 112,120 labeled chest X-ray images.

* Dataset paper: https://arxiv.org/abs/1705.02315
* Dataset: https://nihcc.app.box.com/v/ChestXray-NIHCC/folder/36938765345

The dataset page provides an example script for downloading the chest X-ray images, as well as a spreadsheet mapping each image to its classification labels. I could not find a way to download the labels spreadsheet programmatically, so I saved a copy to a publicly accessible Google Drive link.

* Download script: https://nihcc.app.box.com/v/ChestXray-NIHCC/file/371647823217
* Labels: https://nihcc.app.box.com/v/ChestXray-NIHCC/file/219760887468
* Labels (copy): https://drive.google.com/file/d/1mkOZNfYt-Px52b8CJZJANNbM3ULUVO3f/view?usp=drive_link

**Note**: The dataset is 42+ GB. Expect significant download times.
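Given the download size, two small standard-library helpers may be useful before starting: a free-space check using `shutil.disk_usage`, and a progress callback for the `reporthook` parameter of `urllib.request.urlretrieve`. This is a sketch; the helper names are illustrative, and the disk budget is an assumption. Because the download cell below keeps each `.tar.gz` after extracting it, and already-compressed PNGs are assumed to shrink little under gzip, it budgets roughly twice the 42 GB download.

```python
import shutil

# Compressed size of the full dataset, from the note above.
DATASET_GB = 42
# Assumption: archives are kept after extraction and the extracted PNGs take
# roughly as much space again, so budget about twice the download size.
REQUIRED_GB = 2 * DATASET_GB


def free_gb(path="."):
    """Free space, in GB (10**9 bytes), on the filesystem containing `path`."""
    return shutil.disk_usage(path).free / 1e9


def has_enough_space(required_gb=REQUIRED_GB, path="."):
    """True if at least `required_gb` GB are free on the filesystem containing `path`."""
    return free_gb(path) >= required_gb


def report_progress(block_num, block_size, total_size):
    """Callback for urllib.request.urlretrieve(..., reporthook=report_progress).

    urlretrieve passes the number of blocks transferred so far, the block size
    in bytes, and the total file size (-1 when the server does not report it).
    """
    downloaded = block_num * block_size
    if total_size > 0:
        percent = min(100.0, 100.0 * downloaded / total_size)
        print(f"\r{percent:5.1f}% of {total_size / 1e9:.2f} GB", end="")
    else:
        print(f"\r{downloaded / 1e9:.2f} GB downloaded", end="")


print(f"Free disk space: {free_gb():.1f} GB (budget: about {REQUIRED_GB} GB)")
if not has_enough_space():
    print("Warning: there may not be enough free space for the full dataset.")
```

To use the callback, pass it as `urllib.request.urlretrieve(url, filename, reporthook=report_progress)` and call `print()` afterwards so the next message starts on a new line.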

## Environment

In [9]:
!if command -v nvidia-smi &> /dev/null; then nvidia-smi --query-gpu=name --format=csv,noheader; else echo 'No NVIDIA GPU detected'; fi
!python --version

Tesla T4
Python 3.11.12


In [10]:
import os
import requests
import tarfile
import urllib.request

import torch
import torchvision

print(f"torch: {torch.__version__}")
print(f"torchvision: {torchvision.__version__}")

torch: 2.6.0+cu124
torchvision: 0.21.0+cu124


## Download

In [11]:
for filename in ["train.txt", "test.txt", "val.txt"]:
    response = requests.get(f"https://raw.githubusercontent.com/nyuad-cai/CXR-ML-GZSL/master/dataset_splits/{filename}")
    response.raise_for_status()  # Stop here instead of saving an error page as the split file.

    with open(filename, "w") as f:
        f.write(response.text)

    print(f"Downloaded: {filename}")

Downloaded: train.txt
Downloaded: test.txt
Downloaded: val.txt
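The split files can be sanity-checked right after downloading: load each one and confirm that no image appears in more than one split. This is a minimal sketch with hypothetical helper names. It assumes each non-empty line of a split file begins with an image filename such as `00000001_000.png` and ignores any whitespace-separated fields after it, so it is worth opening one of the files to confirm the format.

```python
from itertools import combinations
from pathlib import Path

SPLIT_NAMES = ["train", "val", "test"]


def parse_split(text):
    """Return the image filenames in a split file's text, in order.

    Assumption: each non-empty line starts with the filename; any further
    whitespace-separated fields (e.g. label columns) are ignored.
    """
    return [line.split()[0] for line in text.splitlines() if line.strip()]


def load_split(path):
    """Read a split file from disk and parse it with parse_split."""
    return parse_split(Path(path).read_text())


def overlapping(splits):
    """Return {(split_a, split_b): shared filenames} for every overlapping pair."""
    sets = {name: set(files) for name, files in splits.items()}
    overlaps = {}
    for a, b in combinations(sets, 2):
        shared = sets[a] & sets[b]
        if shared:
            overlaps[(a, b)] = shared
    return overlaps


# Only runs once the split files from the cell above are on disk.
if all(Path(f"{name}.txt").exists() for name in SPLIT_NAMES):
    splits = {name: load_split(f"{name}.txt") for name in SPLIT_NAMES}
    for name, files in splits.items():
        print(f"{name}: {len(files)} images ({len(set(files))} unique)")
    assert not overlapping(splits), "An image appears in more than one split!"
```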


In [12]:
LABELS_FILE = "Data_Entry_2017_v2020.csv"

response = requests.get("https://drive.google.com/uc?export=download&id=1mkOZNfYt-Px52b8CJZJANNbM3ULUVO3f")
response.raise_for_status()  # Stop here instead of saving an error page as the CSV.

with open(LABELS_FILE, "wb") as f:
    f.write(response.content)

print(f"Downloaded: {LABELS_FILE}")

Downloaded: Data_Entry_2017_v2020.csv
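The labels spreadsheet stores each image's findings as a single `|`-separated string (for example `Cardiomegaly|Effusion`, or `No Finding`). For multi-label training, that string can be converted into a multi-hot vector over the fourteen disease labels. The sketch below uses only the standard `csv` module and hypothetical helper names. It assumes the `Image Index` and `Finding Labels` columns and the label spellings of the NIH release, which are worth checking against the downloaded file; the ordering of `LABELS` is a choice made here, not necessarily the paper's. It also rejects a file without the expected header, which catches the case where the Google Drive link returns an HTML page instead of the CSV.

```python
import csv
from pathlib import Path

# The fourteen ChestX-ray14 disease labels as spelled in the CSV.
# The ordering is a choice for this sketch, not necessarily the paper's.
LABELS = [
    "Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass",
    "Nodule", "Pneumonia", "Pneumothorax", "Consolidation", "Edema",
    "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia",
]
LABEL_INDEX = {name: i for i, name in enumerate(LABELS)}


def to_multi_hot(finding_labels):
    """Convert e.g. "Cardiomegaly|Effusion" into a 0/1 list aligned with LABELS."""
    vector = [0] * len(LABELS)
    if finding_labels == "No Finding":
        return vector
    for name in finding_labels.split("|"):
        vector[LABEL_INDEX[name]] = 1  # A KeyError flags an unexpected label.
    return vector


def read_labels(lines):
    """Map each "Image Index" to its multi-hot vector, given CSV text lines."""
    reader = csv.DictReader(lines)
    header = reader.fieldnames or []
    if "Image Index" not in header or "Finding Labels" not in header:
        raise ValueError("Unexpected header; the download may be an HTML page, not the CSV.")
    return {row["Image Index"]: to_multi_hot(row["Finding Labels"]) for row in reader}


# Same filename as LABELS_FILE above; only runs once the CSV is on disk.
if Path("Data_Entry_2017_v2020.csv").exists():
    with open("Data_Entry_2017_v2020.csv", newline="") as f:
        labels = read_labels(f)
    print(f"Labeled images: {len(labels)}")
```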


In [13]:
# Archive URLs, as listed in the NIH download script linked above.
dataset = [
    {"filename": "images_001.tar.gz", "url": "https://nihcc.box.com/shared/static/vfk49d74nhbxq3nqjg0900w5nvkorp5c.gz"},
    {"filename": "images_002.tar.gz", "url": "https://nihcc.box.com/shared/static/i28rlmbvmfjbl8p2n3ril0pptcmcu9d1.gz"},  # TODO: Add the rest of the dataset
]

for item in dataset:
    filename = item["filename"]
    url = item["url"]

    # Download the archive, then unpack it into the working directory.
    urllib.request.urlretrieve(url, filename)

    with tarfile.open(filename, "r:gz") as tar:
        tar.extractall()

    print(f"Downloaded: {filename}")

IMAGE_PATH = "images"
NUM_IMAGES = 112120

assert os.path.exists(IMAGE_PATH), "Dataset is not in the expected directory!"
# Size check disabled while only part of the dataset is listed above.
# assert len([f for f in os.listdir(IMAGE_PATH) if os.path.isfile(os.path.join(IMAGE_PATH, f))]) == NUM_IMAGES, "Dataset is not the expected size!"

Downloaded: images_001.tar.gz
Downloaded: images_002.tar.gz
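Since only the first two of the dataset's twelve archives are listed so far, most entries in the split files will not have an image on disk yet, which is also why the size assertion is commented out. The sketch below counts the extracted PNGs and reports how many entries of each split are available locally, which may help when prototyping the later sections on a partial download. The helpers are hypothetical, and the code assumes the image filename is the first whitespace-separated field on each line of a split file.

```python
import os

IMAGE_DIR = "images"  # Same directory as IMAGE_PATH above.
SPLIT_NAMES = ["train", "val", "test"]


def list_images(image_dir):
    """Return the set of .png filenames directly inside `image_dir` (empty if it is missing)."""
    if not os.path.isdir(image_dir):
        return set()
    return {
        name
        for name in os.listdir(image_dir)
        if name.lower().endswith(".png") and os.path.isfile(os.path.join(image_dir, name))
    }


def available(split_files, on_disk):
    """Keep only the split entries whose image has been extracted, preserving order."""
    return [name for name in split_files if name in on_disk]


on_disk = list_images(IMAGE_DIR)
print(f"Extracted images: {len(on_disk)}")

for split in SPLIT_NAMES:
    path = f"{split}.txt"
    if not os.path.exists(path):
        continue
    with open(path) as f:
        # Assumption: the image filename is the first field on each line.
        names = [line.split()[0] for line in f if line.strip()]
    print(f"{split}: {len(available(names, on_disk))} of {len(names)} images available locally")
```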


## Preprocessing

## Model

## Training

## Evaluation

## Results