In [1]:
import os
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import pandas as pd
from PIL import Image

In [7]:
import os
import tarfile
import shutil

def decompress_and_collect_images(source_dir):
    """Descomprime .tar.gz y copia todas las imágenes en images_total."""
    output_dir = os.path.join(source_dir, "images_total")
    os.makedirs(output_dir, exist_ok=True)
    
    for file_name in os.listdir(source_dir):
        if file_name.endswith(".tar.gz"):
            tar_path = os.path.join(source_dir, file_name)
            extract_path = os.path.join(source_dir, file_name[:-7])
            
            with tarfile.open(tar_path, "r:gz") as tar:
                tar.extractall(extract_path)
            
            for root, _, files in os.walk(extract_path):
                for f in files:
                    if f.lower().endswith((".png", ".jpg", ".jpeg")):
                        shutil.copy(os.path.join(root, f), output_dir)

In [8]:
decompress_and_collect_images("/home/lingfeng/datasets/CXR8/images")

In [2]:
class CXR8Dataset(Dataset):
    def __init__(self, csv_path, image_dir, transform=None):
        """
        Args:
            csv_path (str): Path to the CSV file.
            image_dir (str): Directory containing the images.
            transform (callable, optional): Transform to apply to images.
        """
        self.data = pd.read_csv(csv_path)
        self.image_dir = image_dir
        self.transform = transform

        available_images = set(os.listdir(self.image_dir))  # List of files in the directory
        self.data = self.data[self.data['id'].isin(available_images)]

        self.labels = self.data.drop(columns=['id', 'subject_id']).values  # Extract labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Get image ID and construct its full path
        img_id = self.data.iloc[idx]['id']
        img_path = os.path.join(self.image_dir, img_id)

        # Load the image
        image = Image.open(img_path).convert("RGB")

        # Apply transformations if provided
        if self.transform:
            image = self.transform(image)

        # Get the corresponding labels
        label = self.labels[idx]
        
        return image, label

In [3]:
def get_tar_gz_paths(paths:list, extension:str):
    return [p for p in paths if p.endswith(extension)]

In [4]:
csv_path = os.path.expanduser("~/datasets/CXR8/LongTailCXR/nih-cxr-lt_single-label_train.csv")
image_dir = os.path.expanduser("~/datasets/CXR8/images/")
batch_size = 32
image_scale = (224, 224)
get_tar_gz_paths(os.listdir(image_dir), ".tar.gz")

['images_005.tar.gz',
 'images_003.tar.gz',
 'images_006.tar.gz',
 'images_002.tar.gz',
 'images_007.tar.gz',
 'images_009.tar.gz',
 'images_004.tar.gz',
 'images_012.tar.gz',
 'images_008.tar.gz',
 'images_011.tar.gz',
 'images_010.tar.gz',
 'images_001.tar.gz']

In [5]:
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert image to grayscale (1 channel)
    transforms.Resize(image_scale),               # Resize images to a uniform size
    transforms.ToTensor(),                       # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize for grayscale (mean and std for a single channel)
])

In [6]:
# Load dataset
dataset = CXR8Dataset(csv_path=csv_path, image_dir=image_dir, transform=transform)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)


ValueError: num_samples should be a positive integer value, but got num_samples=0

In [None]:
print(f"Number of images in the filtered dataset: {len(dataset)}")

# Test the DataLoader
for images, labels in data_loader:
    print(f"Batch of images: {images.shape}")
    print(f"Batch of labels: {labels.shape}")
    break