In [1]:
import os
from dotenv import load_dotenv

from src.prePro import preprocess_metadata
from src.prePro import calculate_label_statistics

# Load environment variables from the .env file and resolve the data directory.
load_dotenv()
data_dir = os.getenv('DATA_DIR')
if not data_dir:
    # os.getenv returns None when the variable is missing; without this guard the
    # f-strings below would silently resolve to '../None/...'.
    raise RuntimeError("DATA_DIR is not set; define it in the environment or in the .env file.")

# Build the filtered metadata table.
# NOTE(review): the three arguments look like (metadata CSV, image folder, output CSV) —
# confirm against src.prePro.preprocess_metadata.
filtered_df = preprocess_metadata(
    f'../{data_dir}/raw/chestScansMD.csv',
    f'../{data_dir}/raw/x-ray_scans',
    f'../{data_dir}/processed/chestScansMD.csv'
)

# Per-label statistics; the result is indexed below by its 'Label' and
# 'Augmentation_Factor' columns.
stats = calculate_label_statistics(filtered_df)

# Convert the label statistics DataFrame to a dictionary for quick lookup
# (label -> augmentation factor).
augmentation_factor_dict = stats.set_index('Label')['Augmentation_Factor'].to_dict()



Total number of images: 34999
Label statistics:
                 Label  Count  Percentage  Augmentation_Factor
0           No Finding  20353   58.153090                    0
1         Infiltration   5048   14.423269                    0
2             Effusion   3370    9.628847                    0
3          Atelectasis   3345    9.557416                    0
4               Nodule   1803    5.151576                    0
5         Pneumothorax   1406    4.017258                    0
6                 Mass   1332    3.805823                    0
7        Consolidation   1241    3.545816                    0
8   Pleural_Thickening   1018    2.908655                    1
9         Cardiomegaly    994    2.840081                    1
10           Emphysema    726    2.074345                    1
11            Fibrosis    721    2.060059                    1
12               Edema    497    1.420041                    2
13           Pneumonia    386    1.102889                    3
14     

In [None]:
from src.augement import save_augmented_images

# Generate augmented copies for every image listed in the preprocessed DataFrame.
# NOTE(review): the input/output folders below are hard-coded to '../data/...', while the
# first cell builds its paths from DATA_DIR — confirm they refer to the same location.
for img_name, disease_labels in zip(filtered_df['ImageID'], filtered_df['Labels']):
    # `disease_labels` is already a list of diseases from preprocess_metadata
    img_path = f'../data/raw/x-ray_scans/{img_name}.png'

    # Highest augmentation factor among this image's labels. Uses the label -> factor
    # dictionary built from `stats` instead of filtering the DataFrame on every row;
    # unknown labels and empty label lists fall back to 0 instead of producing NaN.
    max_augmentation_factor = max(
        (augmentation_factor_dict.get(label, 0) for label in disease_labels),
        default=0,
    )

    # Call the function to save the augmented images with the calculated max factor
    save_augmented_images(img_path, '../data/processed/augmented_x-rays', img_name, disease_labels, max_augmentation_factor)


In [None]:
import pandas as pd

# Function to calculate augmented image statistics and copy augmented DataFrame
def calculate_augmented_image_statistics_and_copy(filtered_df, label_stats_df, augmented_dir, image_id_length=12):
    """
    Summarise the augmented images found on disk and build a metadata table for them.

    Every ``.png`` file directly inside ``augmented_dir`` is matched to its source
    row(s) in ``filtered_df`` through the first ``image_id_length`` characters of the
    file name. Each matched file is counted once, under the label that has the highest
    ``Augmentation_Factor`` among the source image's labels (the first such label wins
    ties; a file whose labels all have factor 0 is not counted under any label).
    Files with no matching ``ImageID`` are ignored.

    Args:
        filtered_df (pd.DataFrame): Original images; must provide the columns
            'ImageID' and 'Labels' (an iterable of label names per row).
        label_stats_df (pd.DataFrame): Must provide the columns 'Label' and
            'Augmentation_Factor'. Labels absent from this table are treated as
            having factor 0.
        augmented_dir (str): The directory where the augmented images are stored.
        image_id_length (int): Number of leading file-name characters that form the
            ImageID. Defaults to 12.

    Returns:
        pd.DataFrame: One row per label of ``label_stats_df`` plus a final 'Total' row,
            with the columns 'Label', 'Augmented_Count', 'Original_Count', 'Total_Count'.
        pd.DataFrame: The matching ``filtered_df`` rows, one per augmented file (in
            file-name order), with the extra columns 'Augmented_ImageID' and
            'Augmentation_Factor'. Empty when no augmented file matches.
    """
    import os

    # Sorted so that the row order does not depend on os.listdir()'s arbitrary order.
    augmented_images = sorted(f for f in os.listdir(augmented_dir) if f.endswith('.png'))

    # Label -> augmentation factor; the first occurrence of a label wins.
    factor_by_label = {}
    for label, factor in zip(label_stats_df['Label'], label_stats_df['Augmentation_Factor']):
        factor_by_label.setdefault(label, factor)

    # Per-label tally of augmented images, in label_stats_df order.
    augmented_counts = dict.fromkeys(factor_by_label, 0)

    # ImageID -> positional row indices, built once instead of scanning the whole
    # DataFrame for every augmented file.
    positions_by_id = {}
    for position, image_id in enumerate(filtered_df['ImageID']):
        positions_by_id.setdefault(image_id, []).append(position)
    labels_column = filtered_df['Labels']

    row_positions = []      # rows of filtered_df to copy into the augmented table
    augmented_names = []    # augmented file name for each copied row
    augmented_factors = []  # highest augmentation factor for each copied row

    for img in augmented_images:
        positions = positions_by_id.get(img[:image_id_length])
        if positions is None:
            continue  # no source record for this file

        # Labels are taken from the first matching record.
        labels = labels_column.iloc[positions[0]]

        # Find the label with the maximum augmentation factor (strict '>' keeps the first on ties).
        max_factor = 0
        label_with_max_factor = None
        for label in labels:
            label_factor = factor_by_label.get(label, 0)
            if label_factor > max_factor:
                max_factor = label_factor
                label_with_max_factor = label

        if label_with_max_factor is not None:
            augmented_counts[label_with_max_factor] += 1

        row_positions.extend(positions)
        augmented_names.extend([img] * len(positions))
        augmented_factors.extend([max_factor] * len(positions))

    # Build the augmented table in one step; this also yields an empty frame (rather
    # than an exception) when nothing matched.
    augmented_df = filtered_df.iloc[row_positions].reset_index(drop=True)
    augmented_df['Augmented_ImageID'] = augmented_names
    augmented_df['Augmentation_Factor'] = augmented_factors

    # Convert the augmented counts to a DataFrame
    augmented_counts_df = pd.DataFrame(list(augmented_counts.items()), columns=['Label', 'Augmented_Count'])

    # Original per-label counts: one count per (image, label) pair in filtered_df
    original_label_counts = filtered_df['Labels'].explode().value_counts().reset_index()
    original_label_counts.columns = ['Label', 'Original_Count']

    # Merge both DataFrames to show original vs augmented stats
    stats_df = pd.merge(augmented_counts_df, original_label_counts, on='Label', how='left')

    # Labels listed in label_stats_df but absent from filtered_df get an original count of 0
    stats_df['Original_Count'] = stats_df['Original_Count'].fillna(0)

    # Add total column (sum of original and augmented counts)
    stats_df['Total_Count'] = stats_df['Augmented_Count'] + stats_df['Original_Count']

    # Add a total row for both original and augmented counts
    total_row = pd.DataFrame({
        'Label': ['Total'],
        'Augmented_Count': [stats_df['Augmented_Count'].sum()],
        'Original_Count': [stats_df['Original_Count'].sum()],
        'Total_Count': [stats_df['Total_Count'].sum()]
    })

    # Append the total row to the stats DataFrame
    stats_df = pd.concat([stats_df, total_row], ignore_index=True)

    return stats_df, augmented_df

# Run the summary over the augmented x-ray folder. `stats` (the label statistics
# computed in the first cell) supplies the augmentation factors.
augmented_stats_df, augmented_images_df = calculate_augmented_image_statistics_and_copy(
    filtered_df=filtered_df,
    label_stats_df=stats,
    augmented_dir='../data/processed/augmented_x-rays',
)

# Show the per-label statistics together with the first rows of the augmented-image table
(augmented_stats_df, augmented_images_df.head())


In [None]:
import torch

# Select the compute device: the GPU when CUDA is available, otherwise the CPU,
# so that `device` is always defined for later cells.
if torch.cuda.is_available():
    device = torch.device("cuda")  # Use GPU
    torch.cuda.empty_cache()  # empty cache
    print('CUDA available')
else:
    device = torch.device("cpu")  # Fall back to CPU
    print('CUDA not available; using CPU')

In [7]:
import pandas as pd
import os
import shutil  # To move files

# Function to calculate augmented image statistics, move images to subfolders, and ensure no repetitive entries
def calculate_augmented_image_statistics_no_repeats(
    filtered_df,
    label_stats_df,
    augmented_dir,
    target_labels=('Hernia', 'Pneumonia', 'Edema', 'Fibrosis', 'Emphysema'),
    image_id_length=12,
):
    """
    Count augmented images per label, sort them into per-label subfolders, and
    return the distinct source records they were generated from.

    Each ``.png`` file is matched to its source row(s) in ``filtered_df`` through the
    first ``image_id_length`` characters of its file name and is counted once, under
    the label with the highest ``Augmentation_Factor`` among the source image's labels
    (first label wins ties; nothing is counted when every factor is 0). When that label
    is one of ``target_labels`` the file is moved to ``augmented_dir/<label>/``.

    Files are looked up both directly in ``augmented_dir`` and inside the
    ``target_labels`` subfolders, so running the function again after images have been
    moved reports the same counts. If the same file name exists in both places, the
    top-level copy is the one that is processed.

    Args:
        filtered_df (pd.DataFrame): Original images; must provide the columns
            'ImageID' and 'Labels' (an iterable of label names per row).
        label_stats_df (pd.DataFrame): Must provide the columns 'Label' and
            'Augmentation_Factor'. Labels absent from this table are treated as
            having factor 0.
        augmented_dir (str): The directory where the augmented images are stored.
        target_labels (iterable of str): Labels that get their own subfolder.
        image_id_length (int): Number of leading file-name characters that form the
            ImageID. Defaults to 12.

    Returns:
        pd.DataFrame: One row per label of ``label_stats_df`` plus a final 'Total' row,
            with the columns 'Label', 'Augmented_Count', 'Original_Count', 'Total_Count'.
        pd.DataFrame: The rows of ``filtered_df`` that have at least one augmented
            image, each ImageID included once (in file-name order of its first
            augmented image). Empty when no augmented file matches.
    """
    # Ensure a subfolder exists for each of the target labels
    label_folders = {label: os.path.join(augmented_dir, label) for label in target_labels}
    for label_folder in label_folders.values():
        os.makedirs(label_folder, exist_ok=True)

    # Collect (folder, file name) pairs: top level first so it takes priority over a
    # same-named file already sitting in a subfolder.
    located_images = []
    seen_filenames = set()
    for folder in [augmented_dir, *label_folders.values()]:
        for filename in os.listdir(folder):
            if filename.endswith('.png') and filename not in seen_filenames:
                seen_filenames.add(filename)
                located_images.append((folder, filename))
    # Process in file-name order so results do not depend on directory listing order.
    located_images.sort(key=lambda located: located[1])

    # Label -> augmentation factor; the first occurrence of a label wins.
    factor_by_label = {}
    for label, factor in zip(label_stats_df['Label'], label_stats_df['Augmentation_Factor']):
        factor_by_label.setdefault(label, factor)

    # Per-label tally of augmented images, in label_stats_df order.
    augmented_counts = dict.fromkeys(factor_by_label, 0)

    # ImageID -> positional row indices, built once instead of scanning the whole
    # DataFrame for every augmented file.
    positions_by_id = {}
    for position, image_id in enumerate(filtered_df['ImageID']):
        positions_by_id.setdefault(image_id, []).append(position)
    labels_column = filtered_df['Labels']

    seen_image_ids = set()  # ImageIDs already added to the unique table
    unique_positions = []   # rows of filtered_df for the unique table

    for folder, img in located_images:
        original_image_id = img[:image_id_length]
        positions = positions_by_id.get(original_image_id)
        if positions is None:
            continue  # no source record for this file

        # Labels are taken from the first matching record.
        labels = labels_column.iloc[positions[0]]

        # Find the label with the maximum augmentation factor (strict '>' keeps the first on ties).
        max_factor = 0
        label_with_max_factor = None
        for label in labels:
            label_factor = factor_by_label.get(label, 0)
            if label_factor > max_factor:
                max_factor = label_factor
                label_with_max_factor = label

        if label_with_max_factor is not None:
            augmented_counts[label_with_max_factor] += 1

        # Record the source row(s) only the first time this ImageID is seen
        if original_image_id not in seen_image_ids:
            seen_image_ids.add(original_image_id)
            unique_positions.extend(positions)

        # Move the image into its label's subfolder unless it is already there
        if label_with_max_factor in label_folders:
            destination_folder = label_folders[label_with_max_factor]
            if folder != destination_folder:
                shutil.move(os.path.join(folder, img), os.path.join(destination_folder, img))

    # Convert the augmented counts to a DataFrame
    augmented_counts_df = pd.DataFrame(list(augmented_counts.items()), columns=['Label', 'Augmented_Count'])

    # Original per-label counts: one count per (image, label) pair in filtered_df
    original_label_counts = filtered_df['Labels'].explode().value_counts().reset_index()
    original_label_counts.columns = ['Label', 'Original_Count']

    # Merge both DataFrames to show original vs augmented stats
    stats_df = pd.merge(augmented_counts_df, original_label_counts, on='Label', how='left')

    # Labels listed in label_stats_df but absent from filtered_df get an original count of 0
    stats_df['Original_Count'] = stats_df['Original_Count'].fillna(0)

    # Add total column (sum of original and augmented counts)
    stats_df['Total_Count'] = stats_df['Augmented_Count'] + stats_df['Original_Count']

    # Add a total row for both original and augmented counts
    total_row = pd.DataFrame({
        'Label': ['Total'],
        'Augmented_Count': [stats_df['Augmented_Count'].sum()],
        'Original_Count': [stats_df['Original_Count'].sum()],
        'Total_Count': [stats_df['Total_Count'].sum()]
    })

    # Append the total row to the stats DataFrame
    stats_df = pd.concat([stats_df, total_row], ignore_index=True)

    # Unique source records; an empty frame (not an exception) when nothing matched
    unique_label_stats_df = filtered_df.iloc[unique_positions].reset_index(drop=True)

    return stats_df, unique_label_stats_df

# Run the no-repeats summary over the augmented x-ray folder. `stats` (the label
# statistics computed in the first cell) supplies the augmentation factors.
augmented_stats_df, unique_label_stats_df = calculate_augmented_image_statistics_no_repeats(
    filtered_df=filtered_df,
    label_stats_df=stats,
    augmented_dir='../data/processed/augmented_x-rays',
)

# Show the per-label statistics together with the table of unique source records
(augmented_stats_df, unique_label_stats_df)


(                 Label  Augmented_Count  Original_Count  Total_Count
 0           No Finding                0           20353        20353
 1         Infiltration                0            5048         5048
 2             Effusion                0            3370         3370
 3          Atelectasis                0            3345         3345
 4               Nodule                0            1803         1803
 5         Pneumothorax                0            1406         1406
 6                 Mass                0            1332         1332
 7        Consolidation                0            1241         1241
 8   Pleural_Thickening                0            1018         1018
 9         Cardiomegaly               48             994         1042
 10           Emphysema              686             726         1412
 11            Fibrosis              667             721         1388
 12               Edema              862             497         1359
 13           Pneumo

In [None]:
from torch.utils.data import DataLoader, Dataset
import torch
from PIL import Image

def load_image(image_path, size=(300, 300)):
    """
    Load an image from disk as a resized grayscale PIL image.

    Args:
        image_path: Path of the image file to open.
        size: Target (width, height) passed to ``resize``. Defaults to 300x300
            (chosen for EfficientNet B2; to be adjusted later).

    Returns:
        The grayscale ('L' mode) image resized to ``size``.
    """
    # The context manager closes the underlying file as soon as the pixel data has
    # been converted, instead of leaving the handle open until garbage collection.
    with Image.open(image_path) as img:
        grayscale = img.convert('L')  # Convert to grayscale
    return grayscale.resize(size)

class ChestXrayDataset(Dataset):
    """
    Dataset that yields ``(image, labels, metadata)`` triples from a metadata DataFrame.

    Columns are read by position: column 0 is the image name, column 1 the disease
    labels handed to ``get_transform``, columns 2-4 are returned as the metadata
    tensor (named age, gender and x-ray view in the code), and the last column holds
    the values converted to the float32 label tensor.

    Args:
        metadata_df (pd.DataFrame): Metadata table laid out as described above.
        transform (callable, optional): Applied to every loaded image. When omitted,
            a per-sample transform is obtained from ``get_transform(disease_labels)``.
        image_dir (str): Folder containing ``<image name>.png`` files.
    """

    def __init__(self, metadata_df, transform=None, image_dir='data/origin/x-ray_scans'):
        self.metadata_df = metadata_df
        self.transform = transform
        self.image_dir = image_dir

    def __len__(self):
        """Number of rows in the metadata table."""
        return len(self.metadata_df)

    def __getitem__(self, idx):
        """Return ``(image, labels, metadata)`` for the row at position ``idx``."""
        img_name = self.metadata_df.iloc[idx, 0]
        img_path = f'{self.image_dir}/{img_name}.png'
        image = load_image(img_path)
        # Last column -> float32 tensor
        # NOTE(review): presumably a multi-hot encoded label vector — confirm upstream.
        labels = torch.tensor(self.metadata_df.iloc[idx, -1], dtype=torch.float32)
        disease_labels = self.metadata_df.iloc[idx, 1]  # Disease labels for this image
        # Honour an explicitly supplied transform; otherwise fall back to the
        # label-conditional transform (the constructor argument used to be ignored).
        if self.transform is not None:
            transform = self.transform
        else:
            transform = get_transform(disease_labels)
        image = transform(image)
        # Additional features
        # NOTE(review): torch.tensor requires all three to be numeric — confirm that
        # gender and x-ray view are numerically encoded upstream.
        age = self.metadata_df.iloc[idx, 2]
        gender = self.metadata_df.iloc[idx, 3]
        xray_view = self.metadata_df.iloc[idx, 4]
        return image, labels, torch.tensor([age, gender, xray_view])

# Create the DataLoader
# NOTE(review): `metadata_df` is not assigned in any cell visible in this notebook —
# confirm which DataFrame is intended and that its column order matches the
# positional indexing used by ChestXrayDataset.
dataset = ChestXrayDataset(metadata_df)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

# Testing the DataLoader: pull a single batch
data_iter = iter(dataloader)
images, labels, metadata = next(data_iter)

# Now you can check the shape of the tensors
print(f"Batch of Images: {images.shape}")  # Should be [batch_size, channels, height, width]
print(f"Batch of Labels: {labels.shape}")  # Should be [batch_size, num_classes]
print(f"Batch of Metadata: {metadata.shape}")  # Should be [batch_size, 3] for age, gender, xray_view

# Visualize the first image and metadata in the batch
# NOTE(review): `plt` is not imported in any visible cell — presumably
# matplotlib.pyplot; confirm and add the import to the notebook's import cell.
# The transpose moves axis 0 of the image tensor to the last position before plotting.
plt.imshow(images[0].cpu().numpy().transpose(1, 2, 0), cmap='gray')  # Display as grayscale image
plt.title(f"Labels: {labels[0]} | Metadata: {metadata[0]}")
plt.show()