In [2]:
import os
import csv
from torchvision import datasets
from torchvision.transforms import ToTensor
from PIL import Image

# Step 1: Load the FashionMNIST dataset
# Both splits share every argument except `train`, so build them in one pass:
# the training split (train=True) first, then the test split (train=False).
# Files are downloaded under the "data" root when requested via download=True,
# and each sample goes through the ToTensor() transform.
training_data, test_data = (
    datasets.FashionMNIST(
        root="data",
        train=is_train_split,
        download=True,
        transform=ToTensor(),
    )
    for is_train_split in (True, False)
)

# Step 2: Define directories
# Everything lives under a single root: one folder per image split plus a
# folder for the CSV label files.
output_dir = "fashion_mnist_images"
train_dir, test_dir, labels_dir = (
    os.path.join(output_dir, subfolder)
    for subfolder in ("train_images", "test_images", "labels")
)

# Create directories if they don't exist (exist_ok makes re-runs harmless)
for _directory in (train_dir, test_dir, labels_dir):
    os.makedirs(_directory, exist_ok=True)

# Step 3: Function to save images and create a CSV file
def save_images_and_csv(dataset, image_directory, csv_directory, csv_filename):
    """
    Save every image of a dataset as a PNG and write a CSV mapping files to labels.

    Images are written as ``<index>.png`` (index = position in the iteration
    order of ``dataset``). The CSV has a ``file_name,label`` header followed by
    one row per image. Existing files with the same names are overwritten.

    Args:
        dataset: Iterable yielding ``(image, label)`` pairs. ``image`` must
            expose ``.numpy()``; its values are assumed to be scaled to
            [0, 1] (the code multiplies by 255 before casting to uint8).
        image_directory: Directory to save the images. Created if missing.
        csv_directory: Directory to save the CSV file. Created if missing.
        csv_filename: Name of the CSV file.
    """
    # Make the function usable on its own instead of relying on the caller
    # having created the target directories. An empty string means "current
    # directory", which os.makedirs would reject, so skip it.
    for directory in (image_directory, csv_directory):
        if directory:
            os.makedirs(directory, exist_ok=True)

    csv_path = os.path.join(csv_directory, csv_filename)
    # newline='' is what the csv module expects so it controls line endings.
    with open(csv_path, mode='w', newline='', encoding='utf-8') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["file_name", "label"])  # Write header row

        for index, (image, label) in enumerate(dataset):
            # Scale to the 0-255 range. Round before the uint8 cast, because
            # astype() truncates and a value fractionally below an integer
            # would otherwise lose one grey level; clip so that out-of-range
            # inputs saturate instead of wrapping around.
            scaled = image.numpy().squeeze() * 255
            pixels = scaled.round().clip(0, 255).astype('uint8')

            file_name = f"{index}.png"
            Image.fromarray(pixels).save(os.path.join(image_directory, file_name))

            # Record which label belongs to the file just written.
            writer.writerow([file_name, label])

# Step 4: Save training data
save_images_and_csv(training_data, train_dir, labels_dir, "train_labels.csv")

# Step 5: Save testing data
save_images_and_csv(test_data, test_dir, labels_dir, "test_labels.csv")

# Report the directories that were actually written to: the image folders are
# train_dir / test_dir ("train_images" / "test_images"), not "train" / "test".
print(f"Images saved in {train_dir} and {test_dir}.")
print(f"Labels saved in {output_dir}/labels.")

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:03<00:00, 7863625.43it/s] 


Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 124746.45it/s]


Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 2309592.82it/s]


Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 8860187.52it/s]


Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Images saved in fashion_mnist_images/train and fashion_mnist_images/test.
Labels saved in fashion_mnist_images/labels.
