In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

# Directory added to the import search path. "../" is a relative path, so it is
# resolved against the kernel's current working directory.
root_dir = "../"
# Appended at the END of sys.path: any already-importable package with the same
# name as a module under root_dir would be found first.
sys.path.append(root_dir)
# NOTE(review): `configs` and `datasets` look like project-local modules that
# live under root_dir — confirm they are not shadowed by installed packages.
from configs import get_config
from datasets import get_dataset

# Pick the experiment configuration by name and load it.
config_name = "chexpert"
config = get_config(config_name)

# Preprocessing pipeline: resize, centre-crop to 224x224, convert to a tensor.
transform = T.Compose(
    [
        T.Resize(224),
        T.CenterCrop((224, 224)),
        T.ToTensor(),
    ]
)
# NOTE(review): train=True presumably selects the training split and
# return_attribute=True presumably adds a per-image attribute vector — confirm
# against get_dataset.
dataset = get_dataset(
    config,
    transform=transform,
    train=True,
    return_attribute=True,
)

# Headline numbers for the loaded dataset.
print(f"Dataset size: {len(dataset):,}")
print(f"Dataset classes ({len(dataset.classes)}): {dataset.classes}")
print(f"Dataset claims ({len(dataset.claims):,}): {dataset.claims}")

# Every class index starts at zero so classes without samples are still listed.
images_per_class = dict.fromkeys(range(len(dataset.classes)), 0)
# The second element of each entry in dataset.samples is used as the class index.
for _, label in dataset.samples:
    images_per_class[label] += 1

print("Image per class:")
for label, count in images_per_class.items():
    print(f"\t{dataset.classes[label]}: {count:,}")

# Plot styling applied to the figures drawn after this point.
sns.set_theme()
sns.set_context("paper")

In [None]:
# Draw one batch of `m` images, show it as a grid, and keep the matching
# `image_attribute` batch for inspection.
m = 24
# Shuffle with a dedicated, seeded generator. An unseeded `shuffle=True` picks a
# different batch on every run, so neither the figure nor the attribute batch
# could be reproduced after Restart & Run All. Re-running this cell rebuilds the
# generator from the same seed and therefore yields the same batch again.
seed = 0
generator = torch.Generator().manual_seed(seed)
dataloader = DataLoader(dataset, batch_size=m, shuffle=True, generator=generator)
data = next(iter(dataloader))
# Each batch unpacks into three parts; the middle one is not used here.
image, _, image_attribute = data

_, ax = plt.subplots(figsize=(16, 9 / 2))
# Tile the batch with 8 images per row.
grid = make_grid(image, nrow=8)
# Move the grid's first axis to the end before handing it to imshow.
ax.imshow(grid.permute(1, 2, 0))
ax.axis("off")
plt.show()

In [None]:
# For every entry of the sampled attribute batch, print the fraction of claims
# marked 1 followed by the names of those claims.
claims = dataset.claims
_num_claims = len(claims)
for _image_attribute in image_attribute:
    # Keep the claim names at the positions whose value equals 1.
    _positive_attribute = [
        claims[idx]
        for idx, flag in enumerate(_image_attribute)
        if flag == 1
    ]
    _positive_fraction = len(_positive_attribute) / _num_claims
    print(_positive_fraction, _positive_attribute)