# A Demo of this Project’s DataLoaders

This project loads galaxy images from the image collection of the GZ: CANDELS project.

In [None]:
import random

import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from torchvision.transforms import ToTensor

from dataset.dataset import get_data_loaders

In [None]:
# Only the first of the three values returned by get_data_loaders() is used here
train_dataset, _, _ = get_data_loaders()

# Draw a single batch and keep its first element (the images)
first_batch = next(iter(train_dataset))
images, _ = first_batch

# Report the dimensions of the image batch
print(images.shape)

# Next: plot a random sample of images from the training set

In [None]:
def tensor_to_np(tensor_img):
    """Convert an image to a channels-last NumPy array clipped to [0, 1].

    Args:
        tensor_img: A ``torch.Tensor`` whose three dimensions are reordered
            with ``permute(1, 2, 0)`` (i.e. the first dimension is moved
            last), or any other object accepted by torchvision's
            ``ToTensor`` (e.g. a PIL image), which is converted first.

    Returns:
        A NumPy array with the first tensor dimension moved to the end and
        every value clipped to the range [0, 1] for display.
    """
    # Anything that is not already a tensor is converted by ToTensor
    if not isinstance(tensor_img, torch.Tensor):
        tensor_img = ToTensor()(tensor_img)

    # detach() lets tensors that track gradients be converted too;
    # .numpy() raises on a tensor that requires grad.
    np_img = tensor_img.detach().cpu().permute(1, 2, 0).numpy()

    # Clip values between 0 and 1 for display
    return np.clip(np_img, 0, 1)


# Select up to 16 random indices from the first 10,000 images
num_images = 16
max_index = 10000

# `train_dataset` is iterated batch-wise earlier, so it is presumably a
# DataLoader, which cannot be indexed with an integer. Index the dataset it
# wraps instead; if there is no `.dataset` attribute, use the object itself.
indexable_dataset = getattr(train_dataset, "dataset", train_dataset)

# Never draw an index past the end of the dataset
# (assumes the dataset defines __len__ -- TODO confirm).
pool_size = min(max_index, len(indexable_dataset))
random_indices = random.sample(range(pool_size), min(num_images, pool_size))

# Create a 4x4 grid
fig, axes = plt.subplots(4, 4, figsize=(8, 8))

for ax, idx in zip(axes.flat, random_indices):
    # Load the raw image from the dataset; the second element is not needed
    raw_image, _ = indexable_dataset[idx]

    # Convert to a displayable NumPy array and draw it
    ax.imshow(tensor_to_np(raw_image))

# Hide the axes of every cell, including any left empty because the
# dataset holds fewer images than the grid has cells.
for ax in axes.flat:
    ax.axis("off")

plt.tight_layout()
plt.show()