In [None]:
from collections import Counter
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset
from PIL import Image

In [None]:
print("Loading NSFW classification dataset...")

# Using FalconLLM/nsfw_image_dataset (balanced, pre-labeled)
# Only the load itself is wrapped in try/except. Every later cell needs
# `dataset`, so a failed load is reported and then re-raised rather than
# swallowed; otherwise the failure would only surface further down as a
# NameError. No fallback dataset is configured, so none is announced.
try:
    dataset = load_dataset("FalconLLM/nsfw_image_dataset", split="train")
except Exception as e:
    print(f"❌ Error: {e}")
    raise

print(f"✅ Loaded {len(dataset)} images")
print(f"Columns: {dataset.column_names}")
# Guard the sample preview so an empty split does not raise on dataset[0].
if len(dataset) > 0:
    print(f"\nSample entry: {dataset[0]}")

In [None]:
# Tally how many examples carry each label value and print the share of each.
# The "label" column is read directly instead of iterating whole rows, so the
# other fields of every example (including the image) are not built just to
# look at one value.
labels = list(dataset["label"])
label_counts = Counter(labels)
n_labels = len(labels)

print("\nLabel Distribution:")
for label, count in label_counts.items():
    print(f"  {label}: {count} ({count/n_labels*100:.1f}%)")

In [None]:
# Show up to eight examples labelled 'safe' in a 2x4 grid and save the figure.
fig, axes = plt.subplots(2, 4, figsize=(15, 8))

# Find the first matching row indices from the label column alone, then fetch
# only those rows, instead of building every row of the dataset and slicing.
# NOTE(review): assumes the label values are the string 'safe'. If the column
# holds integer class ids nothing will match — confirm against the label
# distribution printed earlier.
safe_indices = [i for i, label in enumerate(dataset['label']) if label == 'safe'][:axes.size]
safe_samples = [dataset[i] for i in safe_indices]

if not safe_samples:
    print("⚠️ No samples with label 'safe' were found; the figure will be empty.")

# Hide the axes on every panel so unused slots do not show empty frames.
for ax in axes.flat:
    ax.axis('off')

for ax, sample in zip(axes.flat, safe_samples):
    ax.imshow(sample['image'])
    ax.set_title(f"Label: {sample['label']}")

# Create the output directory first; savefig raises if it does not exist.
output_path = Path('../results/sample_safe_images.png')
output_path.parent.mkdir(parents=True, exist_ok=True)

plt.tight_layout()
fig.savefig(output_path)
plt.show()

print("\n✅ Saved sample visualization to results/sample_safe_images.png")