# DeepArtNet — Data Exploration

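This notebook explores the WikiArt dataset for DeepArtNet's three classification tasks (style, genre, artist). It covers image counts and class distributions per task, class imbalance, image overlap between tasks, sample visualizations, and the inverse-frequency sampling weights used to counter the imbalance.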

In [None]:
import sys
from pathlib import Path

# Project root. NOTE(review): assumes the notebook is launched from a direct
# subdirectory of the repository (e.g. a notebooks/ folder), so the parent of the
# working directory is the root — adjust if the notebook is run from elsewhere.
ROOT = Path().resolve().parent

# Make project-local modules importable. Only prepend when ROOT is not already the
# first entry, so re-running this cell does not keep stacking duplicates onto sys.path.
if sys.path[:1] != [str(ROOT)]:
    sys.path.insert(0, str(ROOT))

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from collections import Counter

# Dataset root: holds the {style,genre,artist}_class.txt and *_train/_val.csv files.
DATA_DIR = ROOT.joinpath('data', 'wikiart')

## 1. Load CSVs and Class Files

In [None]:
def load_class_names(class_file: Path, encoding: str = 'utf-8'):
    """Parse a class-index file into a ``{class_id: class_name}`` dict.

    Each non-blank line is ``<integer id><whitespace><name>``. Everything after the
    first run of whitespace is kept as the name, so names containing spaces survive.

    Args:
        class_file: Path to a ``*_class.txt`` file.
        encoding: Text encoding of the file. Defaults to UTF-8 rather than the
            platform locale so results do not depend on the OS.

    Returns:
        dict mapping ``int`` class id -> ``str`` class name.

    Raises:
        ValueError: if a non-blank line has no name or a non-integer id.
    """
    mapping = {}
    text = class_file.read_text(encoding=encoding)
    for lineno, raw in enumerate(text.splitlines(), start=1):
        line = raw.strip()
        if not line:
            continue
        parts = line.split(maxsplit=1)
        if len(parts) != 2:
            raise ValueError(f'{class_file}:{lineno}: expected "<id> <name>", got {raw!r}')
        mapping[int(parts[0])] = parts[1]
    return mapping

# id -> name lookups for each of the three classification tasks.
style_classes = load_class_names(DATA_DIR / 'style_class.txt')
genre_classes = load_class_names(DATA_DIR / 'genre_class.txt')
artist_classes = load_class_names(DATA_DIR / 'artist_class.txt')

for heading, lookup in (('Styles:  ', style_classes),
                        ('Genres:  ', genre_classes),
                        ('Artists: ', artist_classes)):
    print(f'{heading}{len(lookup)}')

In [None]:
def load_csv(path: Path, label_col: str):
    """Read a header-less split CSV of ``<image path>,<label>`` rows.

    Returns a DataFrame whose columns are ``['path', label_col]``. Later cells treat
    ``path`` as relative to the images directory.
    """
    column_names = ['path', label_col]
    return pd.read_csv(path, names=column_names, header=None)

# Train/val splits per task; every frame has columns ['path', <task>].
style_train, style_val = (load_csv(DATA_DIR / 'style_train.csv', 'style'),
                          load_csv(DATA_DIR / 'style_val.csv', 'style'))
genre_train, genre_val = (load_csv(DATA_DIR / 'genre_train.csv', 'genre'),
                          load_csv(DATA_DIR / 'genre_val.csv', 'genre'))
artist_train, artist_val = (load_csv(DATA_DIR / 'artist_train.csv', 'artist'),
                            load_csv(DATA_DIR / 'artist_val.csv', 'artist'))

print('Train sizes:', *map(len, (style_train, genre_train, artist_train)))
print('Val sizes:  ', *map(len, (style_val, genre_val, artist_val)))

## 2. Class Distribution Plots

In [None]:
def plot_distribution(df, label_col, class_map, title, figsize=(14, 5)):
    """Draw a bar chart of per-class sample counts, each bar annotated with its count.

    Args:
        df: DataFrame with one row per image.
        label_col: Column holding the class id of each row.
        class_map: ``{class_id: class_name}``; ids missing from it are shown as ``str(id)``.
        title: Figure title.
        figsize: Matplotlib figure size in inches.
    """
    counts = df[label_col].value_counts().sort_index()
    labels = [class_map.get(i, str(i)) for i in counts.index]
    positions = np.arange(len(counts))

    fig, ax = plt.subplots(figsize=figsize)
    # Plot against integer positions rather than the label strings: matplotlib's
    # categorical axis would silently merge bars whose display names collide.
    bars = ax.bar(positions, counts.values, color='steelblue', edgecolor='white')
    ax.set_xticks(positions)
    # Anchor rotated labels at their right end so each one sits under its own bar.
    ax.set_xticklabels(labels, rotation=60, ha='right', rotation_mode='anchor')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_ylabel('Count')
    for bar, val in zip(bars, counts.values):
        # Offset the count label in screen points, not data units, so the gap above
        # the bar stays constant regardless of how large the counts are.
        ax.annotate(str(val),
                    xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                    xytext=(0, 2), textcoords='offset points',
                    ha='center', va='bottom', fontsize=7)
    fig.tight_layout()
    plt.show()

# One distribution chart per task, train split only.
for frame, column, names, heading in (
    (style_train, 'style', style_classes, 'Style Distribution (Train)'),
    (genre_train, 'genre', genre_classes, 'Genre Distribution (Train)'),
    (artist_train, 'artist', artist_classes, 'Artist Distribution (Train)'),
):
    plot_distribution(frame, column, names, heading)

## 3. Class Imbalance Summary

In [None]:
def imbalance_summary(df, label_col, class_map):
    """Print the most/least frequent class and the max/min imbalance ratio.

    ``value_counts`` only counts classes that actually occur. Ids present in
    ``class_map`` but absent from ``df`` are therefore reported separately, since the
    ratio alone would hide them.

    Args:
        df: DataFrame with one row per image.
        label_col: Column holding the class id of each row.
        class_map: ``{class_id: class_name}``; unknown ids are shown as ``str(id)``.

    Returns:
        The max/min count ratio, or ``None`` when ``df`` has no labelled rows.
    """
    counts = df[label_col].value_counts()
    if counts.empty:
        # idxmax()/idxmin() raise on an empty Series.
        print('  No labelled samples.')
        return None

    top_id, bottom_id = counts.idxmax(), counts.idxmin()
    top_n, bottom_n = counts[top_id], counts[bottom_id]
    ratio = top_n / bottom_n

    print(f'  Max: {class_map.get(top_id, str(top_id))} ({top_n:,})')
    print(f'  Min: {class_map.get(bottom_id, str(bottom_id))} ({bottom_n:,})')
    print(f'  Imbalance ratio: {ratio:.1f}:1')

    missing = set(class_map) - set(counts.index)
    if missing:
        print(f'  Classes with no samples: {len(missing)}')
    return ratio

# Imbalance report for each task on the train split.
imbalance_specs = {
    'Style': (style_train, 'style', style_classes),
    'Genre': (genre_train, 'genre', genre_classes),
    'Artist': (artist_train, 'artist', artist_classes),
}
for name, (df, col, cmap) in imbalance_specs.items():
    print(f'\n{name}:')
    imbalance_summary(df, col, cmap)

## 4. Image Overlap Across Tasks (Train Split)

In [None]:
# Sets of image paths per task (train split); one painting can be labelled for several tasks.
style_paths, genre_paths, artist_paths = (
    set(frame['path']) for frame in (style_train, genre_train, artist_train)
)
all_paths = set().union(style_paths, genre_paths, artist_paths)
shared_by_all = style_paths & genre_paths & artist_paths

print(f'Total unique images (train): {len(all_paths):,}')
print(f'In all 3 tasks:  {len(shared_by_all):,}')
print(f'Style only:      {len(style_paths - (genre_paths | artist_paths)):,}')
print(f'Genre only:      {len(genre_paths - (style_paths | artist_paths)):,}')
print(f'Artist only:     {len(artist_paths - (style_paths | genre_paths)):,}')

## 5. Sample Images (requires the images to be downloaded to `data/wikiart/images/`)

In [None]:
IMAGE_DIR = DATA_DIR / 'images'

def show_samples(df, label_col, class_map, n=8, title='', seed=42, image_dir=None):
    """Show up to ``n`` randomly sampled images in a single row, titled by class name.

    Args:
        df: DataFrame with a ``'path'`` column (relative to the image directory)
            and ``label_col``.
        label_col: Column holding the class id of each row.
        class_map: ``{class_id: class_name}``; unknown ids are shown as ``str(id)``.
        n: Maximum number of images to show.
        title: Figure super-title.
        seed: ``random_state`` passed to ``DataFrame.sample`` so reruns draw the same rows.
        image_dir: Directory that ``'path'`` is relative to; defaults to ``IMAGE_DIR``.

    Images that are not on disk are drawn as an 'N/A' placeholder instead of raising.
    """
    root = IMAGE_DIR if image_dir is None else Path(image_dir)
    k = min(n, len(df))
    if k <= 0:
        # plt.subplots(1, 0) raises, so report and return instead of crashing.
        print(f'{title or label_col}: nothing to show')
        return
    sample = df.sample(k, random_state=seed)
    # squeeze=False always yields a 2-D array of Axes, so k == 1 needs no special case.
    fig, axes = plt.subplots(1, k, figsize=(3 * k, 3), squeeze=False)
    for ax, (_, row) in zip(axes[0], sample.iterrows()):
        img_path = root / row['path']
        if img_path.exists():
            img = mpimg.imread(str(img_path))
            # Grayscale images load as 2-D arrays; without an explicit gray colormap
            # imshow would render them with the default false-colour colormap.
            ax.imshow(img, cmap='gray' if img.ndim == 2 else None)
        else:
            ax.text(0.5, 0.5, 'N/A', ha='center', va='center', transform=ax.transAxes)
        ax.set_title(class_map.get(row[label_col], str(row[label_col])), fontsize=8)
        ax.axis('off')
    fig.suptitle(title, fontsize=12, fontweight='bold')
    fig.tight_layout()
    plt.show()

# One row of sample images per task.
for frame, column, names, heading in (
    (style_train, 'style', style_classes, 'Style Samples'),
    (genre_train, 'genre', genre_classes, 'Genre Samples'),
    (artist_train, 'artist', artist_classes, 'Artist Samples'),
):
    show_samples(frame, column, names, title=heading)

## 6. WeightedRandomSampler Weights Preview

In [None]:
def compute_weights(df, label_col):
    """Return per-sample inverse-frequency weights, e.g. for ``WeightedRandomSampler``.

    A sample of class ``c`` gets ``N / (K * n_c)``, where ``N`` is the number of
    labelled rows, ``K`` the number of distinct classes and ``n_c`` the size of class
    ``c``. Every class therefore carries the same total weight ``N / K``.

    Returns:
        list of float, aligned with the rows of ``df``.
    """
    labels = df[label_col]
    class_sizes = labels.value_counts()
    class_weight = class_sizes.sum() / (class_sizes.size * class_sizes)
    return labels.map(class_weight).tolist()

style_weights = compute_weights(style_train, 'style')
# Weights are inversely proportional to class size, so max/min equals the count imbalance.
lightest, heaviest = min(style_weights), max(style_weights)
print(f'Style weight range: {lightest:.4f} – {heaviest:.4f}')
print(f'Ratio: {heaviest/lightest:.1f}x  (mirrors imbalance)')