# CXR-ML-GZSL

## Overview

This notebook aims to reproduce the findings of the paper "Multi-Label Generalized Zero Shot Learning for the Classification of Disease in Chest Radiographs" with the help of an LLM. The authors' published code was a valuable reference during the development of this notebook.

* Paper: https://arxiv.org/abs/2107.06563
* Code: https://github.com/nyuad-cai/CXR-ML-GZSL

This notebook reuses the training, validation, and test splits provided by the original paper, as well as its pre-generated BioBERT text embeddings of the class names.

* Dataset split: https://github.com/nyuad-cai/CXR-ML-GZSL/tree/master/dataset_splits
* Class embeddings: https://github.com/nyuad-cai/CXR-ML-GZSL/blob/master/embeddings/nih_chest_xray_biobert.npy

The paper uses the NIH chest X-ray dataset, originally released as `ChestX-ray8` and renamed `ChestX-ray14` when it was expanded from eight to fourteen distinct disease labels. The dataset contains 112,120 labeled chest X-rays.

* Dataset paper: https://arxiv.org/abs/1705.02315
* Dataset: https://nihcc.app.box.com/v/ChestXray-NIHCC/folder/36938765345

The dataset page provides an example script for downloading the chest X-ray images, along with a spreadsheet mapping images to classification labels. However, the dataset split files already contain the image classification labels, so this notebook does not use the spreadsheet.

* Download script: https://nihcc.app.box.com/v/ChestXray-NIHCC/file/371647823217
* Labels: https://nihcc.app.box.com/v/ChestXray-NIHCC/file/219760887468

**Note**: The dataset is ~42 GB. Expect significant download times.

## Environment

In [1]:
!if command -v nvidia-smi &> /dev/null; then nvidia-smi --query-gpu=name --format=csv,noheader; else echo 'No NVIDIA GPU detected'; fi
!python --version

No NVIDIA GPU detected
Python 3.11.12


In [2]:
import multiprocessing
import os
import requests
import tarfile
import urllib.request

import numpy as np

print(f"numpy: {np.__version__}")

import PIL
from PIL import Image

print(f"PIL: {PIL.__version__}")

import torch
from torch.utils.data import DataLoader, Dataset

print(f"torch: {torch.__version__}")

import torchvision
import torchvision.transforms as transforms

print(f"torchvision: {torchvision.__version__}")

numpy: 2.0.2
PIL: 11.1.0
torch: 2.6.0+cu124
torchvision: 0.21.0+cu124


## Download

In [3]:
for filename in ["train.txt", "val.txt", "test.txt"]:
    response = requests.get(f"https://raw.githubusercontent.com/nyuad-cai/CXR-ML-GZSL/master/dataset_splits/{filename}")
    response.raise_for_status()  # Fail loudly instead of saving an HTTP error page

    with open(filename, "w") as f:
        f.write(response.text)

    print(f"Downloaded: {filename}")

Downloaded: train.txt
Downloaded: val.txt
Downloaded: test.txt
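
Each line of a split file holds an image path followed by fourteen 0/1 labels, which is why the separate label spreadsheet is not needed. The sketch below shows how such a line can be parsed. The helper name `parse_split_line` and the example line are hypothetical and made up for illustration; the parsing mirrors what the dataset class later in this notebook does.

```python
import os


def parse_split_line(line):
    """Split one line of a split file into (image file name, list of 0/1 labels)."""
    parts = line.strip().split()
    image_name = os.path.basename(parts[0])  # drop any leading directory
    labels = [int(value) for value in parts[1:]]
    return image_name, labels


# Hypothetical line, made up for illustration (not copied from train.txt).
example_line = "images/00000001_000.png 0 1 0 0 0 0 0 0 0 0 0 0 0 0"

image_name, labels = parse_split_line(example_line)
print(image_name, labels)
```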


In [4]:
CLASS_EMBEDDINGS = "nih_chest_xray_biobert.npy"

response = requests.get(f"https://raw.githubusercontent.com/nyuad-cai/CXR-ML-GZSL/master/embeddings/{CLASS_EMBEDDINGS}")
response.raise_for_status()  # Fail loudly instead of saving an HTTP error page

with open(CLASS_EMBEDDINGS, "wb") as f:
    f.write(response.content)

print(f"Downloaded: {CLASS_EMBEDDINGS}")

class_embeddings = np.load(CLASS_EMBEDDINGS)

Downloaded: nih_chest_xray_biobert.npy
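
A quick sanity check on the class embeddings is to look at the pairwise cosine similarity between class names. The sketch below uses a random stand-in array so that it runs without the download. The shape `(14, 768)` is an assumption (one row per class, BioBERT-sized vectors) and should be checked against `class_embeddings.shape`. The helper `cosine_similarity_matrix` is hypothetical and not part of the original code.

```python
import numpy as np


def cosine_similarity_matrix(embeddings):
    """Return the pairwise cosine similarity between the rows of `embeddings`."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / norms  # scale every row to unit length
    return unit @ unit.T


# Stand-in for np.load("nih_chest_xray_biobert.npy"); the shape is an assumption.
rng = np.random.default_rng(0)
stand_in_embeddings = rng.normal(size=(14, 768))

similarity = cosine_similarity_matrix(stand_in_embeddings)
print(similarity.shape)
```

With the real embeddings, replacing `stand_in_embeddings` with `class_embeddings` should show which class names BioBERT places close together.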


In [5]:
dataset = [
    {"filename": "images_001.tar.gz", "url": "https://nihcc.box.com/shared/static/vfk49d74nhbxq3nqjg0900w5nvkorp5c.gz"},
    {"filename": "images_002.tar.gz", "url": "https://nihcc.box.com/shared/static/i28rlmbvmfjbl8p2n3ril0pptcmcu9d1.gz"},
    {"filename": "images_003.tar.gz", "url": "https://nihcc.box.com/shared/static/f1t00wrtdk94satdfb9olcolqx20z2jp.gz"},
    {"filename": "images_004.tar.gz", "url": "https://nihcc.box.com/shared/static/0aowwzs5lhjrceb3qp67ahp0rd1l1etg.gz"},
    {"filename": "images_005.tar.gz", "url": "https://nihcc.box.com/shared/static/v5e3goj22zr6h8tzualxfsqlqaygfbsn.gz"},
    {"filename": "images_006.tar.gz", "url": "https://nihcc.box.com/shared/static/asi7ikud9jwnkrnkj99jnpfkjdes7l6l.gz"},
    {"filename": "images_007.tar.gz", "url": "https://nihcc.box.com/shared/static/jn1b4mw4n6lnh74ovmcjb8y48h8xj07n.gz"},
    {"filename": "images_008.tar.gz", "url": "https://nihcc.box.com/shared/static/tvpxmn7qyrgl0w8wfh9kqfjskv6nmm1j.gz"},
    {"filename": "images_009.tar.gz", "url": "https://nihcc.box.com/shared/static/upyy3ml7qdumlgk2rfcvlb9k6gvqq2pj.gz"},
    {"filename": "images_010.tar.gz", "url": "https://nihcc.box.com/shared/static/l6nilvfa9cg3s28tqv1qc1olm3gnz54p.gz"},
    {"filename": "images_011.tar.gz", "url": "https://nihcc.box.com/shared/static/hhq8fkdgvcari67vfhs7ppg2w6ni4jze.gz"},
    {"filename": "images_012.tar.gz", "url": "https://nihcc.box.com/shared/static/ioqwiy20ihqwyr8pf4c24eazhh281pbu.gz"},
]

for item in dataset:
    filename = item["filename"]
    url = item["url"]

    urllib.request.urlretrieve(url, filename)

    with tarfile.open(filename, "r:gz") as tar:
        tar.extractall()

    os.remove(filename)

    print(f"Downloaded and extracted: {filename}")

IMAGE_PATH = "images"
NUM_IMAGES = 112120

assert os.path.exists(IMAGE_PATH), "Dataset is not in the expected directory!"
assert len([f for f in os.listdir(IMAGE_PATH) if os.path.isfile(os.path.join(IMAGE_PATH, f))]) == NUM_IMAGES, "Dataset is not the expected size!"

Downloaded and extracted: images_001.tar.gz
Downloaded and extracted: images_002.tar.gz
Downloaded and extracted: images_003.tar.gz
Downloaded and extracted: images_004.tar.gz
Downloaded and extracted: images_005.tar.gz
Downloaded and extracted: images_006.tar.gz
Downloaded and extracted: images_007.tar.gz
Downloaded and extracted: images_008.tar.gz
Downloaded and extracted: images_009.tar.gz
Downloaded and extracted: images_010.tar.gz
Downloaded and extracted: images_011.tar.gz
Downloaded and extracted: images_012.tar.gz


## Preprocessing

In [13]:
class ChestXrayDataset(Dataset):
    def __init__(self, image_dir, labels_file, transform, excluded_classes=None):
        self.image_dir = image_dir
        self.transform = transform
        self.data = []

        with open(labels_file, 'r') as f:
            for line in f:
                parts = line.strip().split()
                image_name = os.path.basename(parts[0])
                labels = list(map(int, parts[1:]))

                if all(l == 0 for l in labels):
                    continue

                if excluded_classes is not None and any(labels[i] == 1 for i in excluded_classes):
                    continue


                self.data.append((image_name, torch.tensor(labels, dtype=torch.int64)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_name, labels = self.data[idx]
        image_path = os.path.join(self.image_dir, image_name)

        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        return image, labels

**Note**: The validation set is not used to tune hyperparameters; it is only used to measure the loss on held-out data during training. Cross-validation is not used.
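
The filtering rules inside `ChestXrayDataset.__init__` can be isolated as a small predicate, which makes them easier to check by hand. This is a minimal sketch: the function name `keep_sample` and the example label vectors are hypothetical. Treating indices 10 to 13 as the held-out (unseen) classes is an assumption based on the `excluded_classes` argument used for the training and validation sets.

```python
def keep_sample(labels, excluded_classes=None):
    """Mirror the dataset filter: drop unlabeled images and images with an excluded class."""
    if all(label == 0 for label in labels):
        return False  # no positive label at all
    if excluded_classes is not None and any(labels[i] == 1 for i in excluded_classes):
        return False  # at least one excluded class is present
    return True


UNSEEN_CLASSES = [10, 11, 12, 13]  # assumption: the classes excluded from training

seen_only = [0] * 14
seen_only[2] = 1           # one positive label among classes 0 to 9

mixed = list(seen_only)
mixed[12] = 1              # additionally carries an excluded class

no_finding = [0] * 14      # no positive label

print(keep_sample(seen_only, UNSEEN_CLASSES))
print(keep_sample(mixed, UNSEEN_CLASSES))
print(keep_sample(mixed))
print(keep_sample(no_finding))
```

Under these rules, an image that shows both an included and an excluded class is removed from the training and validation sets but kept in the test set, where no classes are excluded.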

In [15]:
# Credit: https://github.com/nyuad-cai/CXR-ML-GZSL/blob/master/ChexnetTrainer.py#L104-L130
# Credit: https://github.com/nyuad-cai/CXR-ML-GZSL/blob/master/arguments.py#L22-L23

normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

training_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
])

testing_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.TenCrop(224),
    transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
    transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops]))
])

train_data = ChestXrayDataset(IMAGE_PATH, "train.txt", transform=training_transform, excluded_classes=[10, 11, 12, 13])
val_data   = ChestXrayDataset(IMAGE_PATH, "val.txt",   transform=testing_transform,  excluded_classes=[10, 11, 12, 13])
test_data  = ChestXrayDataset(IMAGE_PATH, "test.txt",  transform=testing_transform)

print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")
print(f"Testing samples: {len(test_data)}")

train_loader = DataLoader(train_data, batch_size=16,    shuffle=True,  num_workers=multiprocessing.cpu_count(), pin_memory=True)
val_loader   = DataLoader(val_data,   batch_size=16*10, shuffle=False, num_workers=multiprocessing.cpu_count(), pin_memory=True)
test_loader  = DataLoader(test_data,  batch_size=16*3,  shuffle=False, num_workers=multiprocessing.cpu_count(), pin_memory=True)

Training samples: 30935
Validation samples: 4383
Testing samples: 10505
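
Because `testing_transform` uses `TenCrop`, every validation and test sample is a stack of ten crops, so a batch from `val_loader` or `test_loader` has shape `[batch, 10, 3, 224, 224]` rather than `[batch, 3, 224, 224]`. One common way to handle this is to fold the crop dimension into the batch dimension, run the model, and average the scores over the ten crops. The sketch below illustrates that under stated assumptions: `predict_ten_crop` is a hypothetical helper, and `stand_in_model` is a toy network used only to make the example runnable, not the model from the paper.

```python
import torch
from torch import nn


def predict_ten_crop(model, images):
    """Average model outputs over the crop dimension of a [B, crops, C, H, W] batch."""
    batch_size, num_crops, channels, height, width = images.shape
    flat = images.reshape(batch_size * num_crops, channels, height, width)
    with torch.no_grad():
        scores = model(flat)  # [B * crops, num_classes]
    return scores.reshape(batch_size, num_crops, -1).mean(dim=1)


# Toy stand-in model (assumption): global average pooling plus one linear layer.
stand_in_model = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 14))
stand_in_model.eval()

batch = torch.randn(2, 10, 3, 224, 224)  # same layout as a test_loader batch
averaged_scores = predict_ten_crop(stand_in_model, batch)
print(averaged_scores.shape)  # torch.Size([2, 14])
```

The ten-fold increase in images per sample is also worth keeping in mind when choosing the evaluation batch sizes.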


## Model

## Training

## Evaluation

## Results