## Exploratory data analysis (EDA) of the ACDC dataset

In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np 
from torch.utils.data import DataLoader
from dataset_ACDC import ACDCdataset


# Build the training split and wrap it in a shuffled, multi-worker loader.
train_data = ACDCdataset(
    base_dir='datasets/ACDC',
    list_dir='datasets/ACDC/lists_ACDC',
    split='train',
)
train_loader = DataLoader(
    train_data,
    batch_size=2,
    shuffle=True,
    num_workers=4,
)

In [2]:
# Pull a single batch from the training loader to inspect its tensor shapes.
sample_batch = next(iter(train_loader))
images, masks = sample_batch['image'], sample_batch['label']

print(f'Image shape: {images.shape}')
print(f'Mask shape: {masks.shape}')

Image shape: torch.Size([2, 1, 256, 256])
Mask shape: torch.Size([2, 256, 256])


In [7]:
# Check the channels/dimensions of all images and labels served by a dataloader
def check_dataset_dimensions(dataloader):
    """
    Check dimensions and label consistency of all samples in a dataloader.

    Every batch must be a mapping with an 'image' and a 'label' entry, each a
    batched ``torch.Tensor`` whose first axis is the batch axis.

    Parameters:
    -----------
    dataloader : torch.utils.data.DataLoader
        The dataloader to check. Any iterable yielding such batches works.

    Returns:
    --------
    dict
        Statistics about the dataset dimensions:
        - "total_samples": number of samples seen (sum of batch sizes)
        - "unique_image_shapes": set of per-sample image shapes
        - "unique_mask_shapes": set of per-sample mask shapes
        - "unique_classes": sorted list of distinct values found in the masks
    """
    # Sets are filled directly so memory does not grow with the dataset size.
    unique_img_shapes = set()
    unique_mask_shapes = set()
    unique_classes = set()
    sample_count = 0

    print("Analyzing dataset dimensions...")

    for batch in dataloader:
        images = batch['image']
        masks = batch['label']

        batch_size = images.shape[0]
        if batch_size == 0:
            # An empty batch carries no samples, so it must not contribute a shape.
            continue
        sample_count += batch_size

        # All samples stacked in one batch tensor share the same trailing shape,
        # so one entry per batch is equivalent to one entry per sample.
        unique_img_shapes.add(images.shape[1:])
        unique_mask_shapes.add(masks.shape[1:])

        # One pass over the whole batch; works regardless of the tensor's device.
        unique_classes.update(torch.unique(masks).tolist())

    class_labels = sorted(unique_classes)

    results = {
        "total_samples": sample_count,
        "unique_image_shapes": unique_img_shapes,
        "unique_mask_shapes": unique_mask_shapes,
        "unique_classes": class_labels
    }

    # Print summary
    print(f"Dataset analysis complete. Found {sample_count} samples.")
    print(f"Image shapes: {unique_img_shapes}")
    print(f"Mask shapes: {unique_mask_shapes}")
    print(f"Unique class labels: {class_labels}")

    if sample_count == 0:
        # Zero shapes is not the same as "multiple shapes"; report it separately.
        print("⚠ EMPTY: The dataloader yielded no samples.")
    elif len(unique_img_shapes) == 1 and len(unique_mask_shapes) == 1:
        print("✓ CONSISTENT: All images and masks have consistent shapes.")
    else:
        print("⚠ INCONSISTENT: Found multiple image or mask shapes.")

    return results

# Analyze shapes and class labels across the full training loader
results = check_dataset_dimensions(dataloader=train_loader)

Analyzing dataset dimensions...
Dataset analysis complete. Found 1304 samples.
Image shapes: {torch.Size([1, 256, 256])}
Mask shapes: {torch.Size([256, 256])}
Unique class labels: [0, 1, 2, 3]
✓ CONSISTENT: All images and masks have consistent shapes.


In [8]:
# Build the validation split and its loader with the same settings as training.
val_data = ACDCdataset(
    base_dir='datasets/ACDC',
    list_dir='datasets/ACDC/lists_ACDC',
    split='valid',
)
val_loader = DataLoader(
    val_data,
    batch_size=2,
    shuffle=True,
    num_workers=4,
)

# Analyze shapes and class labels across the full validation loader
results_val = check_dataset_dimensions(dataloader=val_loader)

Analyzing dataset dimensions...
Dataset analysis complete. Found 182 samples.
Image shapes: {torch.Size([1, 256, 256])}
Mask shapes: {torch.Size([256, 256])}
Unique class labels: [0, 1, 2, 3]
✓ CONSISTENT: All images and masks have consistent shapes.
