In [49]:
from datasets import load_dataset
import pandas as pd
import os
import random

In [50]:
# Load the "train" split of the "huggan/wikiart" dataset.
ds = load_dataset("huggan/wikiart", split="train")
# Metadata-only view: the same dataset with the "image" column removed.
ds_meta = ds.remove_columns(["image"])

# Convert the metadata to a pandas DataFrame for the counting and grouping
# cells below. Nothing is filtered at this point; rows labelled
# "Unknown Artist" / "Unknown Genre" are excluded later, when `filtered_df`
# is built.
# NOTE(review): presumably the image column is dropped first so to_pandas()
# does not have to materialise image data -- confirm.
df = ds_meta.to_pandas()

In [51]:
# Label table for the "artist" feature: one "<class id> <class name>" line
# per entry, in id order.
artist_classes = ds.features["artist"].names
for class_id, class_name in enumerate(artist_classes):
    print(f"{class_id} {class_name}")

0 Unknown Artist
1 boris-kustodiev
2 camille-pissarro
3 childe-hassam
4 claude-monet
5 edgar-degas
6 eugene-boudin
7 gustave-dore
8 ilya-repin
9 ivan-aivazovsky
10 ivan-shishkin
11 john-singer-sargent
12 marc-chagall
13 martiros-saryan
14 nicholas-roerich
15 pablo-picasso
16 paul-cezanne
17 pierre-auguste-renoir
18 pyotr-konchalovsky
19 raphael-kirchner
20 rembrandt
21 salvador-dali
22 vincent-van-gogh
23 hieronymus-bosch
24 leonardo-da-vinci
25 albrecht-durer
26 edouard-cortes
27 sam-francis
28 juan-gris
29 lucas-cranach-the-elder
30 paul-gauguin
31 konstantin-makovsky
32 egon-schiele
33 thomas-eakins
34 gustave-moreau
35 francisco-goya
36 edvard-munch
37 henri-matisse
38 fra-angelico
39 maxime-maufra
40 jan-matejko
41 mstislav-dobuzhinsky
42 alfred-sisley
43 mary-cassatt
44 gustave-loiseau
45 fernando-botero
46 zinaida-serebriakova
47 georges-seurat
48 isaac-levitan
49 joaquã­n-sorolla
50 jacek-malczewski
51 berthe-morisot
52 andy-warhol
53 arkhip-kuindzhi
54 niko-pirosmani
55 james-

In [52]:
# Label table for the "genre" feature: one "<class id> <class name>" line
# per entry, in id order.
genre_classes = ds.features["genre"].names
for class_id, class_name in enumerate(genre_classes):
    print(f"{class_id} {class_name}")

0 abstract_painting
1 cityscape
2 genre_painting
3 illustration
4 landscape
5 nude_painting
6 portrait
7 religious_painting
8 sketch_and_study
9 still_life
10 Unknown Genre


In [53]:
# Label table for the "style" feature: one "<class id> <class name>" line
# per entry, in id order.
style_classes = ds.features["style"].names
for class_id, class_name in enumerate(style_classes):
    print(f"{class_id} {class_name}")

0 Abstract_Expressionism
1 Action_painting
2 Analytical_Cubism
3 Art_Nouveau
4 Baroque
5 Color_Field_Painting
6 Contemporary_Realism
7 Cubism
8 Early_Renaissance
9 Expressionism
10 Fauvism
11 High_Renaissance
12 Impressionism
13 Mannerism_Late_Renaissance
14 Minimalism
15 Naive_Art_Primitivism
16 New_Realism
17 Northern_Renaissance
18 Pointillism
19 Pop_Art
20 Post_Impressionism
21 Realism
22 Rococo
23 Romanticism
24 Symbolism
25 Synthetic_Cubism
26 Ukiyo_e


In [54]:
# How many rows carry a placeholder label? In the label tables printed above,
# artist id 0 is "Unknown Artist" and genre id 10 is "Unknown Genre".
# Summing the boolean comparison counts the matching rows directly.
count_null_artist = int((df["artist"] == 0).sum())
count_null_genre = int((df["genre"] == 10).sum())
print('Rows with "Unknown Artist":', count_null_artist)
print('Rows with "Unknown Genre":', count_null_genre)

Rows with "Unknown Artist": 41914
Rows with "Unknown Genre": 16452


In [55]:
# Keep only rows whose artist AND genre are both known
# (artist id 0 = "Unknown Artist", genre id 10 = "Unknown Genre").
is_labelled = df["artist"].ne(0) & df["genre"].ne(10)
filtered_df = df[is_labelled]

# Number of rows per (artist, genre, style) combination, largest groups first.
group_sizes = filtered_df.groupby(["artist", "genre", "style"]).size()
grouped = group_sizes.reset_index(name="count").sort_values(
    "count", ascending=False
)

In [59]:
# Show the 20 largest (artist, genre, style) groups as a plain-text table,
# without the DataFrame's row index.
top_groups = grouped.head(20)
print("=== Top 20 Groups (Artist|Genre|Style) Counts ===")
print(top_groups.to_string(index=False))

=== Top 20 Groups (Artist|Genre|Style) Counts ===
 artist  genre  style  count
      4      4     12    752
     14      4     24    716
      7      3     23    527
     17      6     12    464
     10      4     21    436
     22      8     21    398
     42      4     12    388
     22      8     20    369
      2      4     12    340
     11      6     21    335
     25      7     17    333
      8      6     21    332
     48      4     21    328
      5      2     12    315
     74      4     21    298
    104      6      9    282
     12      7     15    281
     17      4     12    252
     22      4     20    252
      4      1     12    252


In [58]:
# (artist_id, genre_id, style_id) combinations to export sample images for.
# These are the five largest groups in the "Top 20 Groups" table.
GROUPS = [
    (4, 4, 12),
    (14, 4, 24),
    (7, 3, 23),
    (17, 6, 12),
    (10, 4, 21),
]

# Maximum number of images written per group.
SAMPLES_PER_GROUP = 5
# Fixed seed so re-running this cell selects (and overwrites) the same files
# instead of piling a new random set into the same folders each time.
SAMPLE_SEED = 42
# Image modes Pillow's JPEG writer can encode directly; anything else
# (e.g. RGBA or palette images) must be converted before saving as .jpg.
JPEG_SAFE_MODES = ("L", "RGB", "CMYK")

root = "wikiart_groups"
os.makedirs(root, exist_ok=True)

# Private RNG: deterministic for this cell and independent of whatever other
# code has drawn from the global `random` state.
rng = random.Random(SAMPLE_SEED)

for artist_id, genre_id, style_id in GROUPS:

    mask = (
        (df["artist"] == artist_id) &
        (df["genre"] == genre_id) &
        (df["style"] == style_id)
    )
    # Index labels of the matching rows. They are used below as row positions
    # in `ds`; this assumes `df` still has the default 0..n-1 index produced
    # by `to_pandas()` -- TODO confirm if `df` is ever re-indexed.
    indices = df[mask].index.tolist()

    if not indices:
        print(f"❌ Group {artist_id}-{genre_id}-{style_id}: No images found.")
        continue

    # Sample without replacement; small groups contribute all of their rows.
    selected = rng.sample(indices, k=min(SAMPLES_PER_GROUP, len(indices)))

    out_dir = os.path.join(root, f"group_{artist_id}_{genre_id}_{style_id}")
    os.makedirs(out_dir, exist_ok=True)

    for img_idx in selected:
        img = ds[img_idx]["image"]   # PIL Image
        if img.mode not in JPEG_SAFE_MODES:
            # Saving e.g. an RGBA image as JPEG raises OSError in Pillow.
            img = img.convert("RGB")
        img.save(os.path.join(out_dir, f"{img_idx}.jpg"))

    print(f"✅ Saved {len(selected)} images → {out_dir}")


✅ Saved 5 images → wikiart_groups\group_4_4_12
✅ Saved 5 images → wikiart_groups\group_14_4_24
✅ Saved 5 images → wikiart_groups\group_7_3_23
✅ Saved 5 images → wikiart_groups\group_17_6_12
✅ Saved 5 images → wikiart_groups\group_10_4_21


In [62]:
# Spot-check a reproducible handful of image paths from the test-reference CSV.
# The frame gets its own name so the WikiArt metadata frame `df` used by the
# earlier cells is not overwritten (re-running those cells would otherwise
# fail with a KeyError on "artist").
test_refs = pd.read_csv("data/test_refs.csv")
# Cap the sample size at the number of rows so a short CSV does not raise.
sample_paths = test_refs["file_path"].sample(
    min(10, len(test_refs)), random_state=42
)

print("\n".join(sample_paths))

./data/raw_images/15_17459.jpg
./data/raw_images/12_1830.jpg
./data/raw_images/21_8573.jpg
./data/raw_images/3_13143.jpg
./data/raw_images/24_1689.jpg
./data/raw_images/20_8314.jpg
./data/raw_images/7_10278.jpg
./data/raw_images/21_12412.jpg
./data/raw_images/9_13187.jpg
./data/raw_images/4_6529.jpg
