In [None]:
import os
import sys
import logging
# Notebook-level logger, opened up to DEBUG.
# NOTE(review): no handler is attached here, so DEBUG records from `logger`
# only show up if a handler is configured elsewhere (e.g. logging.basicConfig).
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)

import torch
import matplotlib.pyplot as plt
import numpy as np
import tqdm
from torchvision import transforms
import albumentations as A

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
import himyb.datasets.waterbirds as waterbirds

%load_ext autoreload
%autoreload 2

In [None]:
# Root directory of the Waterbirds dataset. Taken from the WATERBIRDS_ROOT
# environment variable if set; otherwise replace the `...` placeholder below.
ROOT_DIR = os.environ.get("WATERBIRDS_ROOT", ...)

# Seed the global torch / numpy RNGs so the shuffled loader (and any random
# subsampling done while building the datasets) is repeatable across runs.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Per-channel (R, G, B) mean / std used by transforms.Normalize below; these
# are the standard ImageNet normalization statistics.
img_avg = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]

# The same statistics as (1, 3, 1, 1) tensors, so they broadcast against
# CHW / NCHW images when undoing the normalization for plotting.
t_img_avg, t_img_std = (
    torch.tensor(stats).reshape(1, 3, 1, 1) for stats in (img_avg, img_std)
)

In [None]:
# Preprocessing pipeline: resize to 64x64, convert to a float tensor, then
# normalize each channel with img_avg / img_std.
preprocessing_steps = [
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=img_avg, std=img_std),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# The `LargeWaterbirds` dataset with the transform above. It gets its own name
# so it is not confused with the balanced `dataset` built in the next cell,
# which is what the inspection cells below use.
large_dataset = waterbirds.LargeWaterbirds(
    root=ROOT_DIR,
    transform=transform,
)

In [None]:
# Balanced Waterbirds dataset inspected by the cells below.
# `rho` presumably sets the fraction of bias-aligned samples
# (class_label == bias_label) — TODO confirm against get_balanced_waterbirds.
dataset = waterbirds.get_balanced_waterbirds(
    root=ROOT_DIR,
    transform=transform,
    rho=0.70,
)

In [None]:
# Shape of one transformed image; with Resize((64, 64)) + ToTensor in the
# transform this should be (C, 64, 64).
first_sample = dataset[0]
first_sample['image'].shape

In [None]:
# Plot the first images of `dataset`, undoing the normalization for display.
n_show = 5
# squeeze=False keeps `axes` a 2-D array even when n_show == 1.
fig, axes = plt.subplots(1, n_show, figsize=(15, 5), squeeze=False)
for idx, ax in enumerate(axes.flat):
    img = dataset[idx]['image']  # normalized (C, H, W) image tensor
    # Undo Normalize: (C, H, W) * (1, C, 1, 1) broadcasts to (1, C, H, W),
    # so squeeze the leading dim back off.
    img = (img * t_img_std + t_img_avg).squeeze(0)
    # Float round-off can leave values slightly outside [0, 1]; clamp so that
    # imshow does not clip them and emit a warning.
    img = img.clamp(0.0, 1.0)
    ax.imshow(img.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
    ax.axis('off')
plt.show()

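## Testing the balanced dataloader
The `rho` value of the balanced dataloader can be chosen freely. It presumably sets the fraction of bias-aligned samples (`class_label == bias_label`); the last cells of this notebook measure that fraction empirically.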

In [None]:
# Shuffled, balanced Waterbirds DataLoader; its `rho` is chosen independently
# of the one used for `dataset` above.
# NOTE(review): no `transform` is passed here, yet the plotting cell below
# undoes img_avg / img_std normalization on these batches — confirm the
# loader's default transform normalizes with the same statistics.
loader_kwargs = dict(
    root=ROOT_DIR,
    batch_size=32,
    num_workers=4,
    shuffle=True,
    rho=0.90,
)
loader = waterbirds.get_balanced_waterbirds_dataloader(**loader_kwargs)

In [None]:
# Number of samples in the dataset wrapped by the loader.
n_loader_samples = len(loader.dataset)
n_loader_samples

In [None]:
# Draw one mini-batch (a dict keyed by field name) from a fresh iterator over
# the shuffled loader, and keep its image tensor.
batch = next(iter(loader))
imgs = batch['image']
# Layout is (batch_size, C, H, W).
imgs.shape

In [None]:
# Class and bias labels of the batch drawn above, shown side by side.
class_and_bias = (batch["class_label"], batch["bias_label"])
class_and_bias

In [None]:
# Plot the first images of the batch, titled with their class / bias labels.
# Assumes the loader normalizes with img_avg / img_std — TODO confirm.
n_show = min(5, len(imgs))
# squeeze=False keeps `axes` a 2-D array even when n_show == 1.
fig, axes = plt.subplots(1, n_show, figsize=(15, 5), squeeze=False)
for idx, ax in enumerate(axes.flat):
    # Undo Normalize: (C, H, W) * (1, C, 1, 1) broadcasts to (1, C, H, W),
    # so squeeze the leading dim back off.
    img = (imgs[idx] * t_img_std + t_img_avg).squeeze(0)
    # Float round-off can leave values slightly outside [0, 1]; clamp so that
    # imshow does not clip them and emit a warning.
    img = img.clamp(0.0, 1.0)
    ax.imshow(img.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
    class_label = int(batch["class_label"][idx])
    bias_label = int(batch["bias_label"][idx])
    ax.set_title(f"class={class_label}, bias={bias_label}")
    ax.axis('off')
plt.show()

In [None]:
# Count, over one full pass of the loader, how many samples are bias-aligned
# (class_label == bias_label). The loop variable is deliberately not named
# `batch`, so the batch drawn earlier stays consistent with `imgs`.
n_aligned = 0
n_total = 0
for loader_batch in tqdm.tqdm(loader, desc="counting aligned samples"):
    class_labels = loader_batch["class_label"]
    bias_labels = loader_batch["bias_label"]
    n_aligned += (class_labels == bias_labels).sum().item()
    n_total += len(class_labels)

In [None]:
# Empirical fraction of bias-aligned samples seen by the loader; compare it
# with the `rho` passed to the loader. NaN (instead of ZeroDivisionError) if
# the loader yielded no samples.
aligned_fraction = n_aligned / n_total if n_total else float("nan")
n_aligned, n_total, aligned_fraction