In [1]:
from datasets import load_dataset, DatasetDict, load_from_disk
from tqdm import tqdm 
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
import math
from sklearn.decomposition import PCA

# Directory the raw dataset is saved to and later loaded from
# (relative path, resolved against the notebook's working directory).
save_dir = "../Data/"

In [None]:
# download and save raw data
import os

# `DatasetDict.save_to_disk` writes a `dataset_dict.json` file at the top of the
# target directory. Use it as the "already downloaded" marker so that re-running
# the notebook does not fetch and rewrite the whole dataset again.
if not os.path.exists(os.path.join(save_dir, "dataset_dict.json")):
    os.makedirs(save_dir, exist_ok=True)  # create the directory only if it is missing
    ds = load_dataset("jlbaker361/wikiart")
    ds.save_to_disk(save_dir)

In [3]:
# load data from disk
# keep_in_memory=True asks `datasets` to copy the data into RAM rather than
# reading it lazily from the files in save_dir.
ds = load_from_disk(save_dir, keep_in_memory=True)

In [None]:
import sys
import pandas as pd

# Copy of the training split without the 'image' column.
# (An earlier variant dropped 'text', 'name' and 'gen_style' instead.)
train_ds = ds['train'].remove_columns(['image'])

# NOTE(review): the frame is built from the full split (image column included),
# not from train_ds above - confirm which one was intended.
test_df = pd.DataFrame(ds['train'])

# Report the frame's size in bytes and, scaled by 1e-9, in gigabytes.
size_in_bytes = sys.getsizeof(test_df)
print(size_in_bytes, size_in_bytes*10**(-9))

In [None]:
# Preview the first rows of the DataFrame as the cell output.
test_df.head()

In [None]:
# Display the loaded dataset object as the cell output.
ds


In [None]:
# Gather the pixel dimensions of every training example that has an image.
image_widths = []
image_heights = []
for example in tqdm(ds['train'], desc="Extracting dimensions"):
    if 'image' not in example:
        continue
    # the image object's `.size` unpacks as (width, height)
    w, h = example['image'].size
    image_widths.append(w)
    image_heights.append(h)

# Two side-by-side histograms: widths on the left, heights on the right.
fig, (width_ax, height_ax) = plt.subplots(1, 2, figsize=(12, 6))

width_ax.hist(image_widths, bins=30, color='blue', alpha=0.7)
width_ax.set_title("Distribution of Image Widths")
width_ax.set_xlabel("Width (pixels)")
width_ax.set_ylabel("Frequency")

height_ax.hist(image_heights, bins=30, color='green', alpha=0.7)
height_ax.set_title("Distribution of Image Heights")
height_ax.set_xlabel("Height (pixels)")
height_ax.set_ylabel("Frequency")

fig.tight_layout()
plt.show()





In [None]:
# Smallest and largest width/height seen. Each axis is reduced independently,
# so a printed pair need not be the size of any single image.
print(f"MIN image size={min(image_widths), min(image_heights)}")
print(f"MAX image size={max(image_widths), max(image_heights)}")

In [9]:
# 2. Class Imbalance

In [None]:
# Tally how many training examples carry each style label.
style_counts = Counter(
    example['style']
    for example in tqdm(ds['train'], desc="Extracting styles")
    if 'style' in example
)
print("Class distribution in 'style':")
# most_common() yields (style, count) pairs from most to least frequent,
# which keeps the bars of the figure in sorted order
ranked = style_counts.most_common()
for style, count in ranked:
    print(f"{style}: {count}")
keys = [label for label, _ in ranked]
values = [total for _, total in ranked]

plt.figure(figsize=(10, 6))
plt.bar(keys, values, color='skyblue', alpha=0.7)
plt.xticks(rotation=90)
plt.title("Distribution of Painting Styles")
plt.xlabel("Style")
plt.ylabel("Frequency")
plt.tight_layout()
plt.show()
# drop the temporaries created by this cell
del keys, values, ranked, style_counts

In [11]:
# 3. 2 Dimensional PCA with different colors for classes.

In [None]:

# Target image size in pixels; HEIGHT*WIDTH*3 is the pixel-vector length the
# filter further down in this cell expects.
# NOTE: convert_img carries its own copy of these values (64) - keep both in sync.
HEIGHT = 64
WIDTH = 64

# img_arr = [np.array(x['image'].resize((WIDTH, HEIGHT))).reshape(-1) 
#            for x in tqdm(ds['train'], desc="Processing Images for PCA") if 'image' in x]
# converts image to 1d np array by flattening (including rgb channels)
def convert_img(x, width=64, height=64):
    """Add a flattened, normalised pixel vector to one dataset example.

    The image is converted to RGB before resizing so that grayscale, palette
    or RGBA pictures also yield exactly ``width * height * 3`` values instead
    of a vector of a different length.

    Args:
        x: Example mapping whose ``'image'`` entry offers the PIL
            ``convert`` and ``resize`` methods.
        width: Target width in pixels (default 64).
        height: Target height in pixels (default 64).

    Returns:
        The same mapping ``x`` with a new ``'img_pixels'`` entry: a 1-D array
        of length ``width * height * 3`` holding the channel values divided
        by 255.
    """
    # Local import kept from the original; presumably so the function is
    # self-contained when Dataset.map(..., num_proc=...) runs it in workers.
    import numpy as np

    rgb_image = x['image'].convert('RGB').resize((width, height))
    x['img_pixels'] = np.asarray(rgb_image).reshape(-1) / 255
    return x

# TODO use full train set
# Pipeline over the first 5000 training examples:
#   1. select the subset,
#   2. convert each PIL image to resized, normalised pixel values (num_proc=4),
#   3. keep only rows whose pixel vector has the full HEIGHT*WIDTH*3 length
#      (alternatively, images that lack some channels could be padded).
train_ds = (
    ds['train']
    .select(range(5000))
    .map(convert_img, num_proc=4)
    .filter(lambda row: len(row['img_pixels']) == HEIGHT*WIDTH*3)
)


In [None]:
def get_unique_styles(train_ds):
    """Return the distinct ``'style'`` values of a dataset in a stable order.

    Args:
        train_ds: Iterable of examples (for instance one dataset split), each
            indexable by the key ``'style'``.

    Returns:
        list: Every distinct style exactly once, sorted by its string form.
        A bare ``list(set(...))`` may come back in a different order in each
        kernel session, which would silently reshuffle any index derived from
        the position of a style in this list.
    """
    unique_style_set = {x['style'] for x in tqdm(train_ds)}
    # key=str keeps the ordering well-defined even if a style value is None
    return sorted(unique_style_set, key=str)

# Collect the distinct styles present in train_ds, then show them as the cell output.
unique_styles = get_unique_styles(train_ds)
unique_styles

In [None]:
def style2num(style, style_list):
    """Translate a style label into its position within ``style_list``.

    Raises ``ValueError`` (from ``list.index``) when the label is not present.
    """
    position = style_list.index(style)
    return position

def add_style(x, style_list):
    """Store the numeric id of ``x['style']`` under ``x['style_num']``.

    The example ``x`` is modified in place and also returned.
    """
    style_index = style2num(x['style'], style_list)
    x['style_num'] = style_index
    return x

# Add a 'style_num' entry to every row: the index of the row's style within
# unique_styles (run with num_proc=4).
train_ds = train_ds.map(lambda x: add_style(x, unique_styles), num_proc=4)

In [59]:
# Project the flattened images onto their first two principal components.
# For a wide matrix like this one scikit-learn's automatic solver choice can be
# the randomized SVD, so pin the seed to get the same projection on every run.
pca = PCA(n_components=2, random_state=42)
reduced_img_ar = pca.fit_transform(train_ds['img_pixels'])


In [None]:
# Scatter of the two PCA components; the point colour encodes the numeric style id.
fig, ax = plt.subplots()
first_component = reduced_img_ar[:, 0]
second_component = reduced_img_ar[:, 1]
ax.scatter(first_component, second_component, c=train_ds['style_num'])
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_title('Visualization of the Images after PCA. Color coded by Image Style')
plt.show()

In [None]:
# 3a. Alternative to PCA, Average pixel brightness of images

In [17]:
# Mean pixel value of the grayscale ('L') version of every training image.
brightness_values = [
    np.mean(np.array(example['image'].convert('L')))
    for example in ds['train']
    if 'image' in example
]

In [None]:
# Histogram of the per-image mean brightness; the dashed red line marks the
# average over all images.
plt.figure(figsize=(8, 6))
plt.hist(brightness_values, bins=30, color='gray', alpha=0.7)
plt.axvline(sum(brightness_values)/len(brightness_values), color='red', linestyle='dashed', linewidth=1)
plt.title("Distribution of Image Brightness")
plt.xlabel("Brightness (mean pixel value)")
plt.ylabel("Frequency")
# a stray trailing character after this call made the cell a SyntaxError
plt.show()

In [None]:
# 4. Color distributions (TODO: still needs fixing - the images should be processed in batches)

In [19]:
def visualize_rgb_distribution(dataset: DatasetDict, batch_size=32, num_batches=3):
    """Show training images next to histograms of their R, G and B channels.

    One figure is drawn per batch, with one row per image and four columns:
    the image itself, then its red, green and blue channel histograms.

    Args:
        dataset: Mapping with a ``'train'`` split whose examples hold a PIL
            image under ``'image'``.
        batch_size: Number of images (figure rows) per batch.
        num_batches: Maximum number of batches (figures) to draw, starting
            from the first example of the split.
    """
    train_dataset = dataset['train']
    total = len(train_dataset)

    for batch in range(num_batches):
        start_idx = batch * batch_size
        # Clamp to the split size so a partial last batch, or a request for
        # more batches than exist, never selects indices past the end.
        end_idx = min(start_idx + batch_size, total)
        if start_idx >= end_idx:
            break
        batch_samples = train_dataset.select(range(start_idx, end_idx))
        n_rows = end_idx - start_idx
        # batch size x [img, R, G, B]
        # squeeze=False keeps `axes` two-dimensional even for a single row,
        # so axes[i, j] indexing works for every batch size.
        fig, axes = plt.subplots(n_rows, 4, figsize=(20, 5*n_rows), squeeze=False)
        fig.suptitle(f'RGB Distribution - Batch {batch+1}', fontsize=16)

        for i, sample in enumerate(batch_samples):
            img = sample['image']
            # Force three channels so grayscale or RGBA images can still be
            # split into R, G and B below.
            img_array = np.array(img.convert('RGB'))

            axes[i, 0].imshow(img)
            axes[i, 0].set_title(f'Image {start_idx+i+1}')
            axes[i, 0].axis('off')
            for j, color in enumerate(['Red', 'Green', 'Blue']):
                channel_data = img_array[:,:,j].ravel()
                axes[i, j+1].hist(channel_data, bins=256, range=(0,255), color=color.lower(), alpha=0.7)
                axes[i, j+1].set_title(f'{color} Channel')
                axes[i, j+1].set_xlim(0, 255)
                axes[i, j+1].set_ylim(0, img_array.shape[0]*img_array.shape[1]//10)  # Limit y-axis for better visibility

        plt.tight_layout()
        plt.show()
        # release the figure so drawing many batches does not accumulate memory
        plt.close(fig)


In [None]:
batch_size = 5
# if we want to iterate over the whole dataset
# NOTE(review): num_batches is computed here, but the call below passes
# num_batches=1, so only the first batch is drawn - confirm this is intended.
num_batches=math.floor(len(ds['train'])/batch_size)
visualize_rgb_distribution(ds, batch_size=batch_size, num_batches=1)

In [32]:
# Keep one running 256-bin histogram for the whole split instead of appending
# every pixel of every image to a Python list: memory use stays constant no
# matter how many images are processed. (This replaces the former
# `pixel_values` list with the aggregated `pixel_counts` array.)
pixel_counts = np.zeros(256, dtype=np.int64)
for img in tqdm(ds['train'], desc="Extracting pixel values"):
    if 'image' in img:
        grayscale_img = img['image'].convert('L')
        np_img = np.array(grayscale_img).ravel()
        # 8-bit grayscale values lie in 0..255, so bincount with minlength=256
        # returns exactly this image's intensity histogram
        pixel_counts += np.bincount(np_img, minlength=256)
plt.figure(figsize=(8, 6))
# one bar of width 1 per intensity value
plt.bar(np.arange(256), pixel_counts, width=1.0, color='black', alpha=0.7)
plt.title("Distribution of Pixel Intensities")
plt.xlabel("Pixel Intensity")
plt.ylabel("Frequency")
plt.show()


Processing batch 1/147:   0%|          | 0/5 [00:00<?, ?it/s]


TypeError: string indices must be integers, not 'str'

Processing batch 1/147:   0%|          | 0/5 [00:00<?, ?it/s]


TypeError: string indices must be integers, not 'str'

In [None]:
# 5. Small sample of the images

In [24]:
def display_first_images(dataset: DatasetDict, num_images=5):
    """Plot the first ``num_images`` training images in a single row.

    Each panel hides its axes and is titled with the example's ``'style'``
    value (``None`` when the example has no such entry).
    """
    samples = dataset['train'].select(range(num_images))
    fig = plt.figure(figsize=(12, 8))
    for position, example in enumerate(samples, start=1):
        ax = fig.add_subplot(1, num_images, position)
        ax.imshow(example['image'])
        ax.axis('off')
        ax.set_title(example.get('style'))
    fig.tight_layout()
    plt.show()

In [None]:
# Show the first five training images with their style as the title.
display_first_images(ds, num_images=5)