In [None]:
import os
import random
import pandas as pd

# Define paths
DATASET_PATH = "./2750"
FEW_SHOT_PATH = "./few_shot_data"  # Where the few-shot CSVs are saved
VAL_PATH = "./validation_data"  # Where the validation CSV is saved

# Number of validation samples per class and the few-shot sizes to generate
VAL_SAMPLES = 50
FEW_SHOT_SIZES = [1, 2, 4, 8, 16]

# Fixed seed so that re-running produces the same splits; set to None to
# draw a fresh random split on every run.
SEED = 42


def list_class_images(dataset_path):
    """Map each class sub-directory of *dataset_path* to its file paths.

    Class names and file names are sorted because ``os.listdir`` returns
    entries in arbitrary order, which would make seeded sampling
    irreproducible. Entries inside a class directory that are not regular
    files (e.g. nested directories) are skipped.
    """
    images_by_class = {}
    for cls in sorted(os.listdir(dataset_path)):
        cls_dir = os.path.join(dataset_path, cls)
        if not os.path.isdir(cls_dir):
            continue
        paths = [os.path.join(cls_dir, name) for name in sorted(os.listdir(cls_dir))]
        images_by_class[cls] = [path for path in paths if os.path.isfile(path)]
    return images_by_class


def split_validation(images_by_class, val_samples, rng):
    """Draw *val_samples* validation images per class.

    Returns ``(validation, remaining)``: two dicts keyed by class. The
    remaining images keep their original order (no set arithmetic, whose
    iteration order depends on string hashing). The input dict is not
    modified.

    Raises:
        ValueError: if a class holds fewer than *val_samples* images.
    """
    validation = {}
    remaining = {}
    for cls, images in images_by_class.items():
        if len(images) < val_samples:
            raise ValueError(
                f"Class '{cls}' has {len(images)} images, "
                f"cannot sample {val_samples} for validation"
            )
        chosen = rng.sample(images, val_samples)
        chosen_lookup = set(chosen)
        validation[cls] = chosen
        remaining[cls] = [img for img in images if img not in chosen_lookup]
    return validation, remaining


def sample_few_shot(images_by_class, num_samples, rng):
    """Return ``(image_path, class)`` rows with *num_samples* images per class.

    Raises:
        ValueError: if a class holds fewer than *num_samples* images.
    """
    rows = []
    for cls, images in images_by_class.items():
        if len(images) < num_samples:
            raise ValueError(
                f"Class '{cls}' has {len(images)} images left, "
                f"cannot sample {num_samples} few-shot examples"
            )
        rows.extend((img, cls) for img in rng.sample(images, num_samples))
    return rows


# True both when executed as a script and inside a notebook kernel, so the
# names assigned below remain module-level globals there.
if __name__ == "__main__":
    # Ensure output directories exist
    os.makedirs(FEW_SHOT_PATH, exist_ok=True)
    os.makedirs(VAL_PATH, exist_ok=True)

    # Dedicated generator: reproducible without touching the global random state
    rng = random.Random(SEED)

    # Load all images
    all_images = list_class_images(DATASET_PATH)

    # Sample validation set; all_images is rebound to the non-validation images
    validation_set, all_images = split_validation(all_images, VAL_SAMPLES, rng)

    # Save validation set
    val_df = pd.DataFrame(
        [(img, cls) for cls, imgs in validation_set.items() for img in imgs],
        columns=["Image Path", "Class"],
    )
    val_df.to_csv(os.path.join(VAL_PATH, "validation_set.csv"), index=False)

    # Create few-shot datasets (each size is sampled independently)
    for num_samples in FEW_SHOT_SIZES:
        few_shot_data = sample_few_shot(all_images, num_samples, rng)

        few_shot_df = pd.DataFrame(few_shot_data, columns=["Image Path", "Class"])
        few_shot_df.to_csv(os.path.join(FEW_SHOT_PATH, f"few_shot_{num_samples}.csv"), index=False)

    print("Few-shot datasets and validation set successfully created!")
