In [None]:
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import pandas as pd
import numpy as np
from pathlib import Path
import os
import shutil
from zipfile import ZipFile

## Data Extraction

If you don't have the `.h5` file downloaded, run this code.

In [None]:
import numpy as np
from astroNN.datasets import load_galaxy10

# Alternative to reading a local .h5 file: let astroNN supply the dataset.
# NOTE(review): presumably this downloads the file on first use — confirm
# against the astroNN docs before running on a metered connection.
images, labels = load_galaxy10()

Since I already have the `.h5` file downloaded, I will run the following block of code instead.

In [5]:
import h5py
import numpy as np

# Location of the Galaxy10 DECaLS HDF5 file. Set the GALAXY10_H5_PATH
# environment variable to point at your own copy; when it is unset, the
# original author's local path is used so existing setups keep working.
# (`os` and `Path` come from the notebook's top import cell.)
h5_path = Path(
    os.environ.get(
        "GALAXY10_H5_PATH",
        "/home/asus/Desktop/Galaxy_type_classification_project/Galaxy10_DECals.h5",
    )
)

# Copy both datasets into in-memory NumPy arrays so they remain usable after
# the file handle is closed at the end of the `with` block.
with h5py.File(h5_path, "r") as f:
    images = np.array(f["images"])  # (N, 256, 256, 3) uint8 image stack
    labels = np.array(f["ans"])     # (N,) uint8 class ids

In [12]:
# Sanity check on the loaded data: shape and dtype of each array, followed by
# the distinct label values that occur.
for arr in (images, labels):
    print(arr.shape, arr.dtype)
print(np.unique(labels))

(17736, 256, 256, 3) uint8
(17736,) uint8
[0 1 2 3 4 5 6 7 8 9]


In [6]:
# Folder-friendly names for the ten galaxy classes. The position in the list is
# the integer label it names: save_split looks names up as CLASS_NAMES[labels[i]].
CLASS_NAMES = [
    "disturbed",                # 0
    "merging",                  # 1
    "round_smooth",             # 2
    "in_between_round_smooth",  # 3
    "cigar_shaped",             # 4
    "barred_spiral",            # 5
    "unbarred_tight_spiral",    # 6
    "unbarred_loose_spiral",    # 7
    "edge_on_no_bulge",         # 8
    "edge_on_with_bulge",       # 9
]

In [7]:
from sklearn.model_selection import train_test_split

# Split row positions rather than the image array itself, so the large
# `images` array is never copied; the positions are used to index it later.
indices = np.arange(len(images))

# Stage 1: keep 80% for training and hold out 20%, stratified by class label.
train_idx, temp_idx = train_test_split(
    indices,
    stratify=labels,
    test_size=0.2,
    random_state=42,
)

# Stage 2: cut the held-out pool in half (10% validation, 10% test overall),
# again stratified, using only the labels of the held-out rows.
val_idx, test_idx = train_test_split(
    temp_idx,
    stratify=labels[temp_idx],
    test_size=0.5,
    random_state=42,
)

In [9]:
from pathlib import Path

# Root of the on-disk dataset, laid out as <root>/<split>/<class name>/.
base_dir = Path("../data/Galaxy10_DECaLS")

# Create one folder per (split, class) pair; exist_ok makes re-running harmless.
for split in ("train", "val", "test"):
    split_dir = base_dir / split
    for cls in CLASS_NAMES:
        split_dir.joinpath(cls).mkdir(parents=True, exist_ok=True)

In [10]:
def save_split(indices, split_name):
    """Write the images selected by ``indices`` as JPEGs under one split folder.

    Each image is saved to ``base_dir / split_name / <class name> / img_<i>.jpg``,
    where the class name is looked up in ``CLASS_NAMES`` via the image's label
    and ``<i>`` is the image's position in the full ``images`` array.

    Reads the notebook-level ``images``, ``labels``, ``CLASS_NAMES`` and
    ``base_dir``.

    Args:
        indices: Iterable of integer positions into ``images`` / ``labels``.
        split_name: Name of the split sub-folder, e.g. ``"train"``.
    """
    split_dir = base_dir / split_name  # loop-invariant, so build it once

    # Create every class folder up front so saving does not depend on the
    # directory-creation cell having been run first; exist_ok makes this a
    # no-op when the folders are already there.
    for cls_name in CLASS_NAMES:
        (split_dir / cls_name).mkdir(parents=True, exist_ok=True)

    for i in indices:
        cls_name = CLASS_NAMES[labels[i]]
        img = Image.fromarray(images[i])
        img.save(split_dir / cls_name / f"img_{i}.jpg", quality=95)

# Export the three splits in turn: train, then val, then test.
for name, idx in (("train", train_idx), ("val", val_idx), ("test", test_idx)):
    save_split(idx, name)