In [25]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torchvision.io import read_image

class XRayImageDataset(Dataset):
    
    def __init__(self, annotations_file, img_dir, transform_norm=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.mean, self.std = self.compute_mean_std()
        self.transform_norm = transforms.Compose([
            transforms.Normalize(self.mean, self.std)
        ])
        self.target_transform = target_transform

    def compute_mean_std(self):
        means = 0
        stds = 0
        for filename in tqdm(os.listdir(self.img_dir)):
            mean, std = torch.std_mean(read_image(os.path.join(self.img_dir, filename)) / 255)
            means += mean.numpy()
            stds += std.numpy()
        mean = means / self.__len__()
        std = stds / self.__len__()
        return mean, std
        

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path) # PyTorch function, no need to change
        label = self.img_labels.iloc[idx, 2] # class_id column
        if self.transform_norm:
            image = self.transform_norm(image)
        if self.target_transform:
            label = self.target_transform(label)
        return norm_img, label

In [21]:
ROOT = "/kaggle/input/amia-public-challenge-2024"

test_img_path = ROOT + "/test/test"
train_img_path = ROOT + "/train/train"

test_annot_path = ROOT + "/test.csv"
train_annot_path = ROOT + "/train.csv"

In [26]:
from torch.utils.data import DataLoader

train_data = XRayImageDataset(train_annot_path, train_img_path)
print(f'Training images: mean {train_data.mean}, std {train_data.std}')

train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)

/kaggle/input/amia-public-challenge-2024/test/test


100%|██████████| 8573/8573 [02:38<00:00, 54.21it/s]

Training images: mean 0.04664242120787185, std 0.10213025072799406





In [27]:
test_data = XRayImageDataset(test_annot_path, test_img_path)
print(f'Testing images: mean {test_data.mean}, std {test_data.std}')

test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

100%|██████████| 6427/6427 [02:50<00:00, 37.77it/s]

Testing images: mean 0.07274098403775907, std 0.16118353533641264



