# Exploring the Simulated Dataset

This notebook demonstrates the generation and visualization of the simulated shapes dataset.


In [11]:
import matplotlib.pyplot as plt
import numpy as np
import torch

import os
os.chdir("../../..")
from src.diffusion.datasets.simulated_dataset import SimulatedDataset
# For better visualization
%matplotlib inline
plt.rcParams['figure.figsize'] = [15, 15]


In [None]:
# Create dataset with larger image size.
# Later cells index samples 0-9, so the dataset must hold at least 10 items
# (presumably controlled by num_samples -- confirm against SimulatedDataset).
dataset = SimulatedDataset(
    num_samples=10, image_size=1024, max_shapes=3, min_distance=0.5, max_distance=10.0
)
print(f"Dataset size: {len(dataset)}")


In [13]:
def visualize_sample(image_tensor, caption: str) -> None:
    """Display a single sample image in its own figure.

    Args:
        image_tensor: Object exposing ``.numpy()`` that yields a channels-first
            (C, H, W) array; it is transposed to (H, W, C) before plotting.
        caption: Text used as the figure title (wrapped if too long).

    Returns:
        None. The figure is rendered through ``plt.show()``.
    """
    # Convert from (C,H,W) to (H,W,C), the channels-last layout imshow expects
    image = image_tensor.numpy().transpose(1, 2, 0)
    plt.figure(figsize=(15, 15))
    plt.imshow(image)
    plt.title(caption, wrap=True, fontsize=12)
    plt.axis("off")
    plt.show()


In [None]:
# Walk through the first ten dataset entries: print each caption, then draw
# the matching image as its own full-size figure.
for sample_number in range(1, 11):
    image, caption = dataset[sample_number - 1]
    label = f"Sample {sample_number}"
    print(f"\n{label} caption: {caption}")
    visualize_sample(image, label)


In [None]:
# Show all ten samples side by side on a 2x5 grid of subplots
fig, axes = plt.subplots(2, 5, figsize=(25, 10))
axes = axes.ravel()  # flatten the 2-D axes array so panels can be walked in order

captions = []
for position, ax in enumerate(axes, start=1):
    image, caption = dataset[position - 1]
    captions.append(caption)

    # Same (C,H,W) -> (H,W,C) conversion used for the single-sample plots
    image = image.numpy().transpose(1, 2, 0)
    ax.imshow(image)
    ax.set_title(f"Sample {position}", fontsize=10)
    ax.axis("off")

plt.tight_layout()
plt.show()

# List the collected captions once the figure has been shown
print("\nCaptions for all samples:")
for position, text in enumerate(captions, start=1):
    print(f"\nSample {position}: {text}")
