# Create Mini-ImageNet dataset

This notebook is no longer needed: the provided .pickle files can be used directly instead.

In [65]:
import os
import random
from collections import defaultdict
import json
from sklearn.model_selection import train_test_split
import shutil


In [88]:
def _group_images_by_class(image_names, labels):
    """Group ``(file_name, label)`` pairs by the class directory named in each path.

    Each entry of *image_names* is split on ``"/"``; the second-to-last component is
    used as the class name and the last one as the file name (e.g.
    ``"filelists/miniImagenet/n01532829/n01532829_721.JPEG"``). The returned dict keeps
    classes in order of first appearance and maps each to a list of
    ``(file_name, label)`` tuples in JSON order.
    """
    class_data = defaultdict(list)
    for image_path, label in zip(image_names, labels):
        parts = image_path.split("/")
        class_data[parts[-2]].append((parts[-1], label))
    return class_data


def _select_class_splits(all_classes, n_train, n_val, n_test, seed=42):
    """Randomly partition *all_classes* into disjoint train/val/test class lists.

    A private ``random.Random(seed)`` is used so the global RNG is left untouched; the
    train draw matches the former ``random.seed(42); random.sample(...)`` behaviour.
    The remaining pool is sorted before the later draws because iterating a set of
    strings yields a different order in each interpreter run (hash randomization),
    which made the val/test choice irreproducible despite the fixed seed.

    ``n_test=None`` assigns every class left after the train/val draws to the test set.

    Raises:
        ValueError: if a count is negative or the counts exceed ``len(all_classes)``.
    """
    if n_test is None:
        n_test = len(all_classes) - n_train - n_val
    if min(n_train, n_val, n_test) < 0 or n_train + n_val + n_test > len(all_classes):
        raise ValueError(
            f"Cannot split {len(all_classes)} classes into {n_train} train / "
            f"{n_val} val / {n_test} test classes."
        )
    rng = random.Random(seed)
    train = rng.sample(all_classes, n_train)
    remaining = sorted(set(all_classes) - set(train))
    val = rng.sample(remaining, n_val)
    val_set = set(val)
    remaining = [c for c in remaining if c not in val_set]
    test = rng.sample(remaining, n_test)
    return train, val, test


def _move_class_images(split_dir, class_names, class_data, image_base_path):
    """Move every listed image of each class from *image_base_path* into ``split_dir/<class>/``.

    Also writes ``split_dir/<class>/ground_truth.txt`` with one label per line, in JSON
    order, for every image listed for that class, including images that were not found.
    The file is rewritten rather than appended to, so re-running does not duplicate
    labels. Missing source images are reported with a warning and skipped.
    """
    for class_name in class_names:
        target_dir = os.path.join(split_dir, class_name)
        source_dir = os.path.join(image_base_path, class_name)
        os.makedirs(target_dir, exist_ok=True)
        entries = class_data[class_name]

        # Open the ground-truth file once per class instead of once per image.
        with open(os.path.join(target_dir, "ground_truth.txt"), "w") as f:
            f.writelines(str(label) + "\n" for _, label in entries)

        for file_name, _ in entries:
            image_path = os.path.join(source_dir, file_name)
            if os.path.exists(image_path):
                shutil.move(image_path, os.path.join(target_dir, file_name))
            else:
                print(f"Warning: Image {image_path} not found. Skipping...")


def split_miniimagenet_by_class(json_file_path, image_base_path="path/to/dataset", train_classes=64, val_classes=16, test_classes=20,
                               data_dir="./../materials/mini-imagenet"):
    """
    Split a miniImagenet JSON file class-wise into train/val/test folders and move the images.

    Classes (not images) are assigned to splits, so the three splits share no class.
    The assignment is reproducible: a fixed seed (42) and a deterministic pool order are used.
    Images are moved (not copied) from ``image_base_path/<class>/<file>`` to
    ``data_dir/<split>/<class>/<file>``. A ``ground_truth.txt`` with one label per
    listed image is written next to them.

    Args:
    json_file_path (str): Path to a JSON file with parallel lists "image_names" and "image_labels".
    image_base_path (str, optional): Directory containing one sub-folder per class with the images.
        Defaults to "path/to/dataset".
    train_classes (int, optional): Number of classes for the training set. Defaults to 64.
    val_classes (int, optional): Number of classes for the validation set. Defaults to 16.
    test_classes (int or None, optional): Number of classes for the test set. Defaults to 20.
        ``None`` assigns all classes not drawn for train/val to the test set.
    data_dir (str, optional): Output directory for the split data. Defaults to "./../materials/mini-imagenet".

    Returns:
    dict: ``{"train": [...], "val": [...], "test": [...]}`` listing the class names of each split.

    Raises:
    ValueError: if the requested class counts are negative or exceed the number of classes in the JSON.
    """
    with open(json_file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    class_data = _group_images_by_class(data["image_names"], data["image_labels"])

    # Validate and choose the splits before touching the filesystem.
    train_split, val_split, test_split = _select_class_splits(
        list(class_data), train_classes, val_classes, test_classes
    )
    splits = {"train": train_split, "val": val_split, "test": test_split}

    for split_name, split_classes in splits.items():
        split_dir = os.path.join(data_dir, split_name)
        os.makedirs(split_dir, exist_ok=True)
        _move_class_images(split_dir, split_classes, class_data, image_base_path)

    return splits
                

In [89]:
# Example usage: split the classes listed in the JSON and MOVE the matching images
# found under the downloaded archive into train/val/test folders (modifies files on disk).
split_miniimagenet_by_class(
    json_file_path="./../materials/mini-imagenet_split.json",
    image_base_path="D:/Downloads/archive",
)

