In [None]:
import os
import random
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

In [None]:
# --- Locations: replace both placeholders before running the notebook ---------
task_path = "YOUR_PATH_TO_TASKS"    # root directory containing the generated task folders
data_path = "YOUR_PATH_TO_DATASET"  # root directory the image filenames are resolved against

# --- Which task to visualise ---------------------------------------------------
data_name = "CelebA"
task_y, task_a = 8, 20  # presumably dataset attribute indices for y and a — confirm against task generator
task_y_name, task_a_name = "Hair color", "Gender"  # human-readable names used in the figure caption

# --- Task generation parameters (must match an existing task folder) ----------
data_size = 1000
sc, ci, ai = 0.82, 0.41, 0.33  # spurious correlation, label shifts, covariate shifts

# Narrow the root down to this task's folder:
#   <root>/<data_name lowercased>/tasks_y<Y>_a<A>/task_<N>_sc<sc>_ci<ci>_ai<ai>
# (the shift parameters are rendered with two decimals in the folder name)
task_path = os.path.join(
    task_path,
    data_name.lower(),
    f"tasks_y{task_y}_a{task_a}",
    f"task_{data_size}_sc{sc:.2f}_ci{ci:.2f}_ai{ai:.2f}",
)

In [None]:
# Load the per-sample metadata table of the selected task; later cells select
# its "filename", "y" and "a" columns.
metadata_file = os.path.join(task_path, "metadata.csv")
metadata = pd.read_csv(metadata_file)

In [None]:
# Draw a fixed-seed random subset of the task metadata (random_state=42 keeps the
# figure reproducible) and keep only the columns the plot reads: the image
# filename plus the (y, a) group columns.
num_samples = 500
sampled_metadata = (
    metadata
    .sample(n=num_samples, random_state=42)
    .loc[:, ["filename", "y", "a"]]
)

In [None]:
# Create figure and axes. The axes span [0, 2] x [0, 2] in data coordinates so
# that each (y, a) group owns one unit square (quadrant).
fig, ax = plt.subplots(figsize=(6, 6))
ax.set_xlim(0, 2)
ax.set_ylim(0, 2)
ax.axis('off')
ax.set_title(f"Data size: {num_samples}\nSpurious correlation: {sc}, label shift: {ci}, covariate shift: {ai}", fontsize=13)

# Maximum number of thumbnails drawn per group (fewer if the group is smaller).
n_per_group = 8

# Dedicated, seeded RNG for thumbnail jitter so re-running the cell reproduces the
# same layout (the global `random` state would change on every run).
layout_rng = random.Random(1)

# Map each (y, a) group to the lower-left corner of its quadrant.
group_coords = {
    (0, 0): (0, 1),  # Top-left
    (0, 1): (1, 1),  # Top-right
    (1, 0): (0, 0),  # Bottom-left
    (1, 1): (1, 0),  # Bottom-right
}

# Plot thumbnails and the group size in each quadrant.
for (y_val, a_val), (x_base, y_base) in group_coords.items():
    group = sampled_metadata[(sampled_metadata['y'] == y_val) & (sampled_metadata['a'] == a_val)]
    # Cap by the group's own size: asking DataFrame.sample for more rows than the
    # group has raises ValueError (small minority groups are expected here).
    subset = group.sample(n=min(n_per_group, len(group)), random_state=1)

    for filename in subset['filename']:
        img_path = os.path.join(data_path, filename)
        try:
            img = mpimg.imread(img_path)
        except FileNotFoundError:
            # Best-effort: report and skip images that are not present locally.
            print(f"Missing: {filename}")
            continue
        # Random offset inside the quadrant; size is 0.20 x 0.20 data units.
        imagebox = ax.inset_axes([
            x_base + layout_rng.uniform(0.1, 0.9) * 0.9,  # x
            y_base + layout_rng.uniform(0.1, 0.9) * 0.9,  # y
            0.20, 0.20                                     # width, height
        ], transform=ax.transData)
        imagebox.imshow(img)
        imagebox.axis('off')

    # Annotate with the full group size in the subsample, not just the thumbnails shown.
    ax.text(
        x_base + 0.05, y_base + 0.05,
        f"count: {len(group)}",
        fontsize=12,
        color='darkred'
    )

# Draw vertical and horizontal divider lines between the quadrants.
ax.axvline(x=1, color='black', linewidth=2, linestyle='--')
ax.axhline(y=1, color='black', linewidth=2, linestyle='--')

# Caption naming the two variables that define the grid.
fig.text(
    0.5,           # x-position (0 = far left, 1 = far right)
    0.0,           # y-position (closer to 0 is near bottom)
    f"{task_y_name} v.s. {task_a_name}",
    ha='center',   # horizontal alignment
    fontsize=11
)

fig.tight_layout()

# Save as a high-resolution image; create the target directory if it is missing
# so a fresh checkout does not fail at the very last step.
output_path = "../assets/task1.png"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.savefig(output_path, dpi=300, bbox_inches='tight')