In [1]:
import os
import pandas as pd
from dotenv import load_dotenv
from src.prePro import preprocess_metadata, calculate_balanced_label_statistics, stratified_split_by_individual_labels
from src.utils import load_image, compute_mean_std
from src.XRayDataset import XRayDataset
from torch.utils.data import DataLoader

# Load environment variables from the .env file and read the data folder name.
load_dotenv()
data_dir = os.getenv('DATA_DIR')
if not data_dir:
    # os.getenv returns None for an unset variable; without this guard the
    # f-strings below would silently build paths such as "../None/raw/...".
    raise RuntimeError(
        "DATA_DIR is not set; define it in the environment or in the .env file."
    )

# All paths are built relative to the parent directory of the notebook.
# NOTE(review): the three arguments look like (raw metadata CSV, raw image
# folder, processed metadata CSV) -- confirm against src.prePro.
filtered_df = preprocess_metadata(
    f'../{data_dir}/raw/xraysMD.csv',
    f'../{data_dir}/raw/xrays',
    f'../{data_dir}/processed/xraysMD.csv'
)

# Label statistics of the filtered metadata, kept for later inspection.
filtered_df_stats = calculate_balanced_label_statistics(filtered_df)



In [3]:
def _label_counts(frame):
    """Return how often each label occurs in ``frame['Labels']``.

    Every entry of the ``Labels`` column is iterated, so each label of a
    multi-label row is counted once.
    """
    return pd.Series(
        [label for labels in frame['Labels'] for label in labels]
    ).value_counts()


train_df, test_df = stratified_split_by_individual_labels(filtered_df, train_size=7000, test_size=3000)

# Print sizes of the resulting DataFrames
print(f"Training set size: {len(train_df)}")
print(f"Test/Validation set size: {len(test_df)}")

# Calculate the label distribution for the training and test sets
train_distribution = _label_counts(train_df)
test_distribution = _label_counts(test_df)

# Combine both distributions into one table (rows are aligned on the label name).
statistics_df = (
    pd.DataFrame({
        'Training': train_distribution,
        'Test/Validation': test_distribution
    })
    .fillna(0)    # a label missing from one split becomes NaN during alignment
    .astype(int)  # NaN alignment upcasts to float; counts are whole numbers
)

# Add a total row to both columns. It sums label occurrences, so it can be
# larger than the number of images when rows carry more than one label.
statistics_df.loc['Total'] = statistics_df.sum()

print("\nCombined label distribution statistics:")
print(statistics_df)


Training set size: 7000
Test/Validation set size: 3000

Combined label distribution statistics:
                    Training  Test/Validation
Atelectasis              667              265
Cardiomegaly             199               93
Consolidation            236              120
Edema                     93               32
Effusion                 685              310
Emphysema                154               67
Fibrosis                 144               75
Hernia                    16                9
Infiltration            1044              395
Mass                     259              106
No Finding              4048             1755
Nodule                   373              151
Pleural_Thickening       188              101
Pneumonia                 78               32
Pneumothorax             283              114
Total                   8467             3625


In [3]:
# Compute mean and std using a list of image paths taken from the training split.
# The image folder is derived from DATA_DIR so it matches the folder handed to
# preprocess_metadata, instead of a hard-coded "../data/..." path.
image_paths = [f"../{data_dir}/raw/xrays/{image_id}.png" for image_id in train_df['ImageID']]
img_size = 1000  # Set to your desired size

mean, std = compute_mean_std(image_paths, img_size)

print(f"Computed Mean: {mean}, Computed Std: {std}")


Computed Mean: 139.45093172971428, Computed Std: 61.9332860849686


In [8]:
# Normalisation statistics: rounded copies of the values printed by the
# compute_mean_std cell (139.4509..., 61.9332...), hard-coded here so this cell
# does not require re-running that computation. Update them if the training
# split or img_size changes.
mean = 139.45
std = 61.93
img_size = 1000

# Single source for the image folder, derived from DATA_DIR so it stays
# consistent with the paths used during preprocessing.
image_dir = f'../{data_dir}/raw/xrays/'

# Create training and validation datasets; both use the same mean and std.
train_dataset = XRayDataset(dataframe=train_df, image_dir=image_dir, img_size=img_size, mean=mean, std=std)
val_dataset = XRayDataset(dataframe=test_df, image_dir=image_dir, img_size=img_size, mean=mean, std=std)

# Create DataLoaders for batching; only the training loader shuffles.
batch_size = 16

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=1)


In [10]:
# Smoke test: pull a single batch from the training loader and inspect it.
data_iter = iter(train_loader)
first_batch = next(data_iter)
images, labels = first_batch

# First sample of the batch, used for the per-item checks below.
first_image, first_label = images[0], labels[0]

# Batch-level shapes
print(f"Images batch shape: {images.shape}")  # Should be [batch_size, channels, height, width]
print(f"Labels batch shape: {labels.shape}")  # Should be [batch_size, num_labels]

# Individual data point (optional)
print(f"First image shape: {first_image.shape}")
print(f"First label: {first_label}")

Images batch shape: torch.Size([16, 1, 1000, 1000])
Labels batch shape: torch.Size([16, 15])
First image shape: torch.Size([1, 1000, 1000])
First label: tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
