# CXR-ML-GZSL

## Overview

This notebook aims to reproduce the findings of the paper "Multi-Label Generalized Zero Shot Learning for the Classification of Disease in Chest Radiographs" with the help of an LLM. The paper's authors provide code, but this notebook directly reuses only their split of the dataset into training, validation, and test sets. The rest of their code served as a reference during development.

* Paper: https://arxiv.org/abs/2107.06563
* Dataset split: https://github.com/nyuad-cai/CXR-ML-GZSL/tree/master/dataset_splits

The paper uses the `ChestX-ray14` dataset, which was originally released as `ChestX-ray8` and renamed when it was expanded from eight to fourteen distinct disease labels. The dataset contains 112,120 labeled chest X-rays.

* Dataset paper: https://arxiv.org/abs/1705.02315
* Dataset: https://nihcc.app.box.com/v/ChestXray-NIHCC/folder/36938765345

The dataset page provides an example script for downloading the chest X-ray images and a spreadsheet that maps images to classification labels. However, the "dataset split" files already contain the image classification labels, so this notebook does not use the spreadsheet.

* Download script: https://nihcc.app.box.com/v/ChestXray-NIHCC/file/371647823217
* Labels: https://nihcc.app.box.com/v/ChestXray-NIHCC/file/219760887468
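
To illustrate why the split files are sufficient on their own, here is a minimal sketch of parsing one line. It assumes each line holds an image path followed by space-separated 0/1 flags, one per disease label, which is how this notebook reads the files later on. The function name `parse_split_line` and the example line are hypothetical and are not taken from the repository.

```python
import os


def parse_split_line(line: str, num_classes: int) -> tuple[str, list[int]]:
    """Split one line into (image file name, list of 0/1 label flags)."""
    parts = line.strip().split()
    # Drop any directory prefix so only the file name remains.
    image_name = os.path.basename(parts[0])
    # Keep the first `num_classes` label columns.
    labels = [int(value) for value in parts[1:num_classes + 1]]
    return image_name, labels


# Hypothetical line, made up for illustration only.
example_line = "images/00000001_000.png 0 1 0 0 0 0 0 0 0 0 0 0 0 0"

name, labels = parse_split_line(example_line, num_classes=14)
print(name, labels)
```

Passing a smaller `num_classes` keeps only the leading label columns, which is how a subset of the labels can be read from the same file format.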

**Note**: The dataset is ~42 GB. Expect significant download times.
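
Given the size, it may be worth checking free disk space before starting the download. The sketch below uses only the standard library. The helper name `has_free_space` and the 2x safety margin are assumptions made for illustration, not requirements of the dataset.

```python
import shutil

DATASET_GB = 42  # approximate dataset size quoted in the note above


def has_free_space(path: str, required_gb: float) -> bool:
    """Return True if the filesystem containing `path` has enough free space."""
    free_gb = shutil.disk_usage(path).free / 1024**3
    return free_gb >= required_gb


# Assumed margin: room for the extracted images plus archives still on disk.
if not has_free_space(".", 2 * DATASET_GB):
    print("Warning: free disk space may be too low for the full dataset.")
```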

## Environment

In [1]:
!if command -v nvidia-smi &> /dev/null; then nvidia-smi --query-gpu=name --format=csv,noheader; else echo 'No NVIDIA GPU detected'; fi
!python --version

NVIDIA A100-SXM4-40GB
Python 3.11.12


In [2]:
import multiprocessing
import os
import requests
import tarfile
import urllib.request

import PIL
from PIL import Image

print(f"PIL: {PIL.__version__}")

import torch
from torch.utils.data import DataLoader, Dataset

print(f"torch: {torch.__version__}")

import torchvision
import torchvision.transforms as transforms

print(f"torchvision: {torchvision.__version__}")

PIL: 11.1.0
torch: 2.6.0+cu124
torchvision: 0.21.0+cu124


## Download

In [3]:
for filename in ["train.txt", "val.txt", "test.txt"]:
    response = requests.get(f"https://raw.githubusercontent.com/nyuad-cai/CXR-ML-GZSL/master/dataset_splits/{filename}")
    response.raise_for_status()  # stop here rather than write an HTTP error page to disk

    with open(filename, "w") as f:
        f.write(response.text)

    print(f"Downloaded: {filename}")

Downloaded: train.txt
Downloaded: val.txt
Downloaded: test.txt


In [4]:
dataset = [
    # TODO: add the rest of the dataset
    {"filename": "images_001.tar.gz", "url": "https://nihcc.box.com/shared/static/vfk49d74nhbxq3nqjg0900w5nvkorp5c.gz"},
]

for item in dataset:
    filename = item["filename"]
    url = item["url"]

    urllib.request.urlretrieve(url, filename)

    with tarfile.open(filename, "r:gz") as tar:
        tar.extractall()

    os.remove(filename)

    print(f"Downloaded and extracted: {filename}")

IMAGE_PATH = "images"
NUM_IMAGES = 112120

assert os.path.exists(IMAGE_PATH), "Dataset is not in the expected directory!"
# assert len([f for f in os.listdir(IMAGE_PATH) if os.path.isfile(os.path.join(IMAGE_PATH, f))]) == NUM_IMAGES, "Dataset is not the expected size!"

Downloaded and extracted: images_001.tar.gz


## Preprocessing

In [5]:
class ChestXrayDataset(Dataset):
    def __init__(self, image_dir, labels_file, num_classes, transform):
        self.image_dir = image_dir
        self.num_classes = num_classes
        self.transform = transform
        self.samples = []

        with open(labels_file, 'r') as f:
            for line in f:
                parts = line.strip().split()
                image_name = os.path.basename(parts[0]) # remove <path>/ from <path>/<image_name>

                # Skip images that have not been downloaded yet.
                # TODO: remove once the full dataset is downloaded
                if not os.path.isfile(os.path.join(self.image_dir, image_name)):
                    continue

                labels = list(map(int, parts[1:self.num_classes + 1]))
                self.samples.append((image_name, labels))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_name, label = self.samples[idx]
        image_path = os.path.join(self.image_dir, image_name)

        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)

        label = torch.tensor(label, dtype=torch.int)

        return image, label

**Note**: The validation set is not used to tune hyperparameters. It is used only to measure the loss on held-out data during training. Cross-validation is not used.

In [7]:
# Credit: https://github.com/nyuad-cai/CXR-ML-GZSL/blob/master/ChexnetTrainer.py#L104-L130
# Credit: https://github.com/nyuad-cai/CXR-ML-GZSL/blob/master/arguments.py#L22-L23

normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

training_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
])

testing_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.TenCrop(224),
    transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
    transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops]))
])

# Training and validation read the first 10 label columns; testing reads all 14.
train_data = ChestXrayDataset(IMAGE_PATH, "train.txt", num_classes=10, transform=training_transform)
val_data   = ChestXrayDataset(IMAGE_PATH, "val.txt",   num_classes=10, transform=testing_transform)
test_data  = ChestXrayDataset(IMAGE_PATH, "test.txt",  num_classes=14, transform=testing_transform)

print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")
print(f"Testing samples: {len(test_data)}")

# TODO: enable these checks once the full dataset is downloaded
# assert len(train_data) == 30758, "Training dataset is not the right size!"
# assert len(val_data) == 4474, "Validation dataset is not the right size!"
# assert len(test_data) == 10510, "Test dataset is not the right size!"

assert multiprocessing.cpu_count() > 10, f"CPU only has {multiprocessing.cpu_count()} cores"

train_loader = DataLoader(train_data, batch_size=16,    shuffle=True,  num_workers=10, pin_memory=True)
val_loader   = DataLoader(val_data,   batch_size=16*10, shuffle=False, num_workers=10, pin_memory=True)
test_loader  = DataLoader(test_data,  batch_size=16*3,  shuffle=False, num_workers=10, pin_memory=True)

Training samples: 3544
Validation samples: 410
Testing samples: 1045
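
Because `testing_transform` applies `TenCrop`, each validation or test sample is a stack of ten crops, so the loaders yield image batches shaped `[batch, 10, 3, 224, 224]` rather than `[batch, 3, 224, 224]`. A common way to handle this, sketched below, is to fold the crop dimension into the batch dimension, run the model once, and average the per-crop outputs. The helper `predict_ten_crop` and the toy linear model are hypothetical stand-ins for illustration; they are not the model used in the paper.

```python
import torch
from torch import nn


def predict_ten_crop(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Average per-crop outputs for a batch shaped [batch, crops, C, H, W]."""
    batch, crops, channels, height, width = images.shape
    # Fold the crop dimension into the batch dimension.
    flat = images.reshape(batch * crops, channels, height, width)
    with torch.no_grad():
        outputs = model(flat)  # [batch * crops, classes]
    # Restore the crop dimension and average over it.
    return outputs.reshape(batch, crops, -1).mean(dim=1)  # [batch, classes]


# Hypothetical stand-in model: a single linear layer over flattened pixels.
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 10))
toy_model.eval()

toy_batch = torch.randn(2, 10, 3, 224, 224)  # 2 samples, 10 crops each
scores = predict_ten_crop(toy_model, toy_batch)
print(scores.shape)  # torch.Size([2, 10])
```

Note that with `TenCrop` the effective number of images per forward pass is ten times the loader's `batch_size`, which matters when choosing batch sizes for the validation and test loaders.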


## Model

## Training

## Evaluation

## Results