In [23]:
import os
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import pandas as pd
from PIL import Image

In [24]:
class CXR8Dataset(Dataset):
    """Map-style dataset pairing image files on disk with label rows from a CSV.

    Only CSV rows whose ``id`` value names an entry present in ``image_dir``
    are kept, so the dataset can be built from a partial image download
    without editing the CSV.

    Attributes:
        data (pandas.DataFrame): CSV rows that survived the availability filter.
        image_dir (str): Directory the image files are read from.
        transform (callable or None): Applied to each loaded PIL image.
        image_ids (list): File name of each kept row, in row order.
        labels (numpy.ndarray): 2-D array holding every column except
            ``id`` and ``subject_id``; row ``i`` belongs to ``image_ids[i]``.
    """

    def __init__(self, csv_path, image_dir, transform=None):
        """
        Args:
            csv_path (str): Path to the CSV file. It must contain the columns
                'id' and 'subject_id'; all remaining columns are used as labels.
            image_dir (str): Directory containing the images.
            transform (callable, optional): Transform to apply to images.

        Raises:
            KeyError: If the CSV lacks the 'id' or 'subject_id' column.
            OSError: If ``csv_path`` or ``image_dir`` cannot be read.
        """
        self.data = pd.read_csv(csv_path)
        self.image_dir = image_dir
        self.transform = transform

        # Keep only the rows whose image file is actually present on disk.
        available_images = set(os.listdir(self.image_dir))
        self.data = self.data[self.data['id'].isin(available_images)]

        # Snapshot ids and labels together so they always stay row-aligned and
        # __getitem__ does not have to build a pandas row object per sample.
        self.image_ids = self.data['id'].tolist()
        self.labels = self.data.drop(columns=['id', 'subject_id']).values  # Extract labels

    def __len__(self):
        """Return the number of samples that have an image on disk."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx (int): Position of the sample; negative values count from the end.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the RGB image (after
            ``transform`` when one was given) and ``label`` is the matching
            row of ``labels``.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # Get image ID and construct its full path
        img_id = self.image_ids[idx]
        img_path = os.path.join(self.image_dir, img_id)

        # Load the image; the context manager releases the file handle
        # deterministically, and convert() returns an independent in-memory copy.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        # Apply transformations if provided
        if self.transform:
            image = self.transform(image)

        # Get the corresponding labels
        label = self.labels[idx]

        return image, label

In [32]:
# Loader / preprocessing settings.
image_scale = (224, 224)  # (height, width) handed to transforms.Resize
batch_size = 32

# Dataset locations; "~" is expanded to the current user's home directory.
image_dir = os.path.expanduser("~/datasets/CXR8/images/images_001/images/")
csv_path = os.path.expanduser("~/datasets/CXR8/LongTailCXR/nih-cxr-lt_single-label_train.csv")

In [33]:
# Preprocessing pipeline, applied in order to every PIL image:
#   1. collapse to a single grayscale channel
#   2. resize to the configured image_scale
#   3. convert to a PyTorch tensor
#   4. normalize the single channel with mean 0.5 and std 0.5
transform = transforms.Compose(
    [
        transforms.Grayscale(1),
        transforms.Resize(size=image_scale),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

In [34]:
# Build the filtered dataset, then wrap it in a shuffling, multi-worker loader.
dataset = CXR8Dataset(csv_path, image_dir, transform=transform)
data_loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)


In [35]:
# Report how many CSV rows survived the image-availability filter.
print(f"Number of images in the filtered dataset: {len(dataset)}")

# Smoke-test the DataLoader by pulling a single batch; an empty loader
# yields nothing, in which case no shapes are printed.
first_batch = next(iter(data_loader), None)
if first_batch is not None:
    images, labels = first_batch
    print(f"Batch of images: {images.shape}")
    print(f"Batch of labels: {labels.shape}")

Number of images in the filtered dataset: 2895
Batch of images: torch.Size([32, 1, 224, 224])
Batch of labels: torch.Size([32, 20])
