In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
from torch.utils.data import DataLoader

from dataset_class import MessidorOpenCVDataset
from preprocess_class import OpenCV_DR_Preprocessor
from transforms import light_transform, heavy_transform


# Where the MESSIDOR data lives. Set the MESSIDOR_ROOT environment variable to
# run this notebook against a different copy of the data; the fallback is the
# path this notebook was originally written with.
MESSIDOR_ROOT = os.environ.get('MESSIDOR_ROOT',
                               '/Users/abohane/Desktop/THEIA Training/MESSIDOR')

# Loader settings gathered in one place so they are easy to find and tune.
BATCH_SIZE = 32
NUM_WORKERS = 4

# Preprocessor constructed with apply_clahe=True.
preprocessor = OpenCV_DR_Preprocessor(apply_clahe=True)

# Create dataset. No light/heavy transform is passed here, and class 3 is
# handed over as the minority-class list (how it is used is decided inside
# MessidorOpenCVDataset).
dataset = MessidorOpenCVDataset(root_dir=MESSIDOR_ROOT,
                                preprocessor=preprocessor,
                                light_transform=None,
                                heavy_transform=None,
                                minority_classes=[3])

# Create dataloader. shuffle=True means batch position i is NOT dataset row i.
dataloader = DataLoader(dataset,
                        batch_size=BATCH_SIZE,
                        shuffle=True,
                        num_workers=NUM_WORKERS,
                        pin_memory=True)


In [None]:
# Sanity check: pull a single batch and report the tensor shapes.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    images, labels = first_batch
    print(images.shape)  # expected: torch.Size([32, 3, 224, 224])
    print(labels.shape)  # expected: torch.Size([32])


In [None]:
# Quick look at the label table behind the dataset.
label_table = dataset.data
print(label_table.columns)  # column names
print(len(label_table))     # total number of samples
# Number of samples per retinopathy grade
print(label_table['Retinopathy grade'].value_counts())



In [None]:
# Per-grade sample counts, ordered by grade so the bars read left to right.
hist = dataset.data['Retinopathy grade'].value_counts().sort_index()
# Draw the class distribution as a bar chart
ax = hist.plot.bar()
ax.set_xlabel('Retinopathy Grade')
ax.set_ylabel('Count')
plt.xticks(rotation=0)  # keep the grade labels horizontal
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def show_batch_images(dataloader, dataset, n_images=8):
    """
    Plots a few images from one batch of the dataloader with filename and label.

    The grade shown in each title is the label the DataLoader returned for
    that exact image. The filename is looked up in ``dataset.data`` only when
    the loader walks the dataset in order (sequential sampler); with a
    shuffled loader batch position ``i`` is not dataset row ``i``, so the
    filename is reported as "Unknown" instead of showing a wrong one.

    Args:
        dataloader: PyTorch DataLoader yielding ``(images, labels)`` batches.
        dataset: Dataset object; when it has a ``data`` table, its
            'Image name' column is used for the filename lookup.
        n_images: Maximum number of images to show. Capped at the batch size.
    """
    from torch.utils.data import SequentialSampler

    # Get one batch
    images, labels = next(iter(dataloader))

    # Never index past the end of the batch (e.g. a small final/only batch).
    n_images = min(n_images, len(images))

    # Undo normalization (ImageNet stats)
    # NOTE(review): assumes the pipeline normalized with these stats - confirm.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Batch position == dataset row only if indices are produced in order.
    # Anything else (shuffling, custom samplers) makes a row lookup unsafe.
    batch_sampler = getattr(dataloader, 'batch_sampler', None)
    index_source = getattr(batch_sampler, 'sampler', None)
    # presumably dataset item i corresponds to dataset.data.iloc[i]; verify
    # against the dataset class if it resamples or augments its index.
    can_lookup_filename = (isinstance(index_source, SequentialSampler)
                           and hasattr(dataset, 'data'))

    # Two rows; round the column count up so odd counts (and 1) still fit.
    n_cols = max(1, (n_images + 1) // 2)

    # Create a figure
    plt.figure(figsize=(20, 8))

    for idx in range(n_images):
        # (C, H, W) -> (H, W, C), then de-normalize into the [0, 1] range
        img = images[idx].permute(1, 2, 0).cpu().numpy()
        img = np.clip((img * std) + mean, 0, 1)

        # The label from the batch always belongs to this image.
        grade = labels[idx].item()
        if can_lookup_filename:
            filename = dataset.data.iloc[idx]['Image name']
        else:
            filename = "Unknown"

        # Plot
        plt.subplot(2, n_cols, idx + 1)
        plt.imshow(img)
        plt.axis('off')
        plt.title(f"{filename}\nGrade: {grade}", fontsize=10)

    plt.tight_layout()
    plt.show()


In [None]:
show_batch_images(dataloader, dataset, n_images=8)
