In [1]:
import os
import random

def load_data_from_file(file_path):
    """Read the text file at ``file_path`` and return its lines as a list.

    Line endings are kept on each element; an empty file yields an empty list.
    """
    with open(file_path, 'r') as handle:
        lines = list(handle)
    return lines

def sample_data(data, num_samples=500):
    """Draw up to ``num_samples`` lines, spread as evenly as possible over classes.

    The class of a line is its last whitespace-separated token. Every class
    first receives an equal quota; the first ``num_samples % num_classes``
    classes (in order of first appearance) get one extra. A class too small
    to fill its quota contributes all of its lines, and the shortfall is then
    topped up one line at a time, round-robin, from classes that still have
    lines which were not drawn yet.

    Args:
        data: Iterable of text lines. Surrounding whitespace is stripped and
            blank lines are ignored.
        num_samples: Target number of lines to return.

    Returns:
        A list of stripped lines. It has ``num_samples`` entries when the
        data can supply that many, otherwise as many as are available.

    Raises:
        ValueError: If ``num_samples`` is negative.
    """
    if num_samples < 0:
        raise ValueError("num_samples must be non-negative")

    # Group lines by class, keeping classes in order of first appearance.
    class_dict = {}
    for line in data:
        entry = line.strip()
        if not entry:
            continue  # a blank line has no class token to read
        class_id = entry.split()[-1]  # the class ID is the last token
        class_dict.setdefault(class_id, []).append(entry)

    if not class_dict:
        return []  # nothing to sample; also avoids dividing by zero classes

    min_samples_per_class, remaining_samples = divmod(num_samples, len(class_dict))

    sampled_data = []
    spare_pools = []  # per class: shuffled lines that were not drawn in the first pass
    for position, items in enumerate(class_dict.values()):
        num_to_sample = min_samples_per_class + (1 if position < remaining_samples else 0)
        if len(items) >= num_to_sample:
            chosen = random.sample(items, num_to_sample)
        else:
            chosen = list(items)  # class is too small: take everything it has
        sampled_data.extend(chosen)

        # dict.fromkeys de-duplicates while preserving order, so the pool does
        # not depend on set iteration order; shuffling once up front replaces
        # rebuilding the candidate list for every single top-up draw.
        taken = set(chosen)
        spare = [entry for entry in dict.fromkeys(items) if entry not in taken]
        random.shuffle(spare)
        spare_pools.append(spare)

    # Top up round-robin from classes with surplus until the target is met.
    while len(sampled_data) < num_samples:
        progressed = False
        for spare in spare_pools:
            if not spare:
                continue
            sampled_data.append(spare.pop())
            progressed = True
            if len(sampled_data) >= num_samples:
                break
        if not progressed:
            break  # every pool is exhausted; the data cannot supply more lines

    return sampled_data

def write_sampled_data(sampled_data, output_file):
    """Save the sampled entries to ``output_file``, one entry per line.

    The file is opened in write mode, so existing content is replaced, and a
    newline is appended after every entry.
    """
    with open(output_file, 'w') as sink:
        sink.writelines(entry + '\n' for entry in sampled_data)

def main(base_directory="/media/ruanjiacheng/新加卷/ecodes/Prompt/data/vtab-1k", num_samples=500):
    """Create a class-balanced subset of ``test.txt`` for every dataset directory.

    Each entry of ``base_directory`` that contains a ``test.txt`` file is
    processed independently: the file is loaded, sampled with ``sample_data``
    and the result is written next to it as ``test_adv_<num_samples>.txt``.
    Entries without a ``test.txt`` are skipped silently.

    Args:
        base_directory: Directory whose subdirectories hold the datasets.
            Defaults to the location this script was written for.
        num_samples: Number of lines requested per dataset; it is also part
            of the output file name.
    """
    # os.listdir returns entries in arbitrary order; sort for reproducible runs.
    for subdir in sorted(os.listdir(base_directory)):
        subdir_path = os.path.join(base_directory, subdir)
        test_file_path = os.path.join(subdir_path, 'test.txt')

        # Skip anything that does not provide a test.txt.
        if not os.path.isfile(test_file_path):
            continue

        data = load_data_from_file(test_file_path)
        sampled_data = sample_data(data, num_samples)
        output_file = os.path.join(subdir_path, f'test_adv_{num_samples}.txt')
        write_sampled_data(sampled_data, output_file)
        print(f"Sampling complete. Data written to {output_file} with sample len of {len(sampled_data)}")

# Run the sampling job only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()


Sampling complete. Data written to /media/ruanjiacheng/新加卷/ecodes/Prompt/data/vtab-1k/caltech101/test_adv_500.txt with sample len of 500
Sampling complete. Data written to /media/ruanjiacheng/新加卷/ecodes/Prompt/data/vtab-1k/cifar/test_adv_500.txt with sample len of 500
Sampling complete. Data written to /media/ruanjiacheng/新加卷/ecodes/Prompt/data/vtab-1k/clevr_count/test_adv_500.txt with sample len of 500
Sampling complete. Data written to /media/ruanjiacheng/新加卷/ecodes/Prompt/data/vtab-1k/clevr_dist/test_adv_500.txt with sample len of 500
Sampling complete. Data written to /media/ruanjiacheng/新加卷/ecodes/Prompt/data/vtab-1k/diabetic_retinopathy/test_adv_500.txt with sample len of 500
Sampling complete. Data written to /media/ruanjiacheng/新加卷/ecodes/Prompt/data/vtab-1k/dmlab/test_adv_500.txt with sample len of 500
Sampling complete. Data written to /media/ruanjiacheng/新加卷/ecodes/Prompt/data/vtab-1k/dsprites_loc/test_adv_500.txt with sample len of 500
Sampling complete. Data written to /me