In [11]:
import torch
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os

# --- Configuration ---
output_filename = "img_3.png"  # Path of the PNG written by save_image
data_root = "./data"  # Root directory passed to datasets.MNIST (download target)
sample_index = 4  # Zero-based index of the sample to fetch (4 -> the fifth image)

# --- Load MNIST Dataset ---
# Only ToTensor is applied, so each sample comes back as a tensor scaled to [0, 1].
try:
    mnist_dataset = datasets.MNIST(
        root=data_root,
        train=True,  # Training split; using False changes which image sample_index selects
        download=True,  # Download into data_root if the files are not already present
        transform=transforms.ToTensor(),  # Converts image to [0, 1] tensor
    )
except Exception as e:
    # Broad catch is deliberate at this top-level boundary: any download/IO/parse
    # failure ends the script. Raise SystemExit(1) rather than calling the
    # site-provided exit(): exit() with no argument reports success (status 0),
    # and in a notebook kernel it asks the kernel to shut down.
    print(f"Error loading MNIST dataset: {e}")
    raise SystemExit(1) from e


# --- Get a Single Sample ---
# Reject both too-large and negative out-of-range indices up front, so the user
# gets a readable message instead of an IndexError raised inside the dataset.
if not 0 <= sample_index < len(mnist_dataset):
    print(f"Error: Sample index {sample_index} is out of bounds for MNIST dataset (size {len(mnist_dataset)}).")
else:
    image_tensor, label = mnist_dataset[sample_index]
    print(f"Loaded sample {sample_index} with label: {label}")
    # With the ToTensor transform, image_tensor is [1, H, W] (e.g. [1, 28, 28]).

    # --- Convert to 3 Channels ---
    # Replicate the single grayscale channel along dim 0 so the PNG is written as RGB.
    image_3channel = image_tensor.repeat(3, 1, 1)
    # image_3channel is now [3, H, W] (e.g. [3, 28, 28]).

    # --- Save the Image ---
    # save_image expects values in [0, 1]: it scales by 255 and clamps before writing.
    try:
        save_image(image_3channel, output_filename)
    except Exception as e:
        # Top-level boundary: report the failure rather than crash the cell.
        print(f"Error saving image: {e}")
    else:
        # Report success (and image statistics) only after the write has completed.
        print(f"Successfully saved 3-channel MNIST image to {output_filename}")
        print(f"Image details - Shape: {image_3channel.shape}, Min: {image_3channel.min():.2f}, Max: {image_3channel.max():.2f}")


Loaded sample 4 with label: 9
Successfully saved 3-channel MNIST image to img_3.png
Image details - Shape: torch.Size([3, 28, 28]), Min: 0.00, Max: 1.00
