In [1]:
import os
import torch
import clip
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import shutil
import data_utils

def select_diverse_indices(features, num_select, start_index=0):
    """Greedily pick up to ``num_select`` mutually dissimilar rows of ``features``.

    Starting from ``start_index``, each step adds the not-yet-chosen row whose
    mean cosine distance (``1 - dot product``) to the rows chosen so far is
    largest. Rows are expected to be L2-normalised so that the dot product is
    the cosine similarity.

    Instead of materialising the full N x N distance matrix, a running sum of
    distances to the chosen rows is kept; the argmax of the sum equals the
    argmax of the mean because every candidate is divided by the same count.

    Args:
        features: 2-D tensor of shape [N, D] with one (normalised) row per item.
        num_select: maximum number of rows to pick.
        start_index: row used to seed the selection (defaults to the first row).

    Returns:
        List of selected row indices in the order they were picked; at most
        ``min(num_select, N)`` entries, empty when there is nothing to pick.
    """
    num_items = features.shape[0]
    if num_items == 0 or num_select <= 0:
        return []

    # Work in float32 regardless of the dtype the encoder produced.
    features = features.float()

    selected = [start_index]
    taken = torch.zeros(num_items, dtype=torch.bool, device=features.device)
    taken[start_index] = True
    # dist_sum[i] = sum of cosine distances from item i to every selected item.
    dist_sum = 1 - features @ features[start_index]

    target = min(num_select, num_items)
    while len(selected) < target:
        # Already-selected rows must never win the argmax.
        candidate_scores = dist_sum.masked_fill(taken, float("-inf"))
        next_idx = int(torch.argmax(candidate_scores))
        selected.append(next_idx)
        taken[next_idx] = True
        dist_sum += 1 - features @ features[next_idx]
    return selected


# The guard keeps DataLoader worker processes (num_workers > 0) from re-running
# the whole script when the main module is re-imported on spawn-based platforms.
# In a notebook or when run as a script, __name__ is "__main__" and everything
# below executes as before, leaving the same module-level names defined.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, clip_preprocess = clip.load("ViT-B/16", device=device)
    dataset_list = ["CIFAR10"]
    NUM_FS = 8  # number of few-shot images to keep per class

    for dataset_name in dataset_list:
        split = "train"
        output_root = "./data/selected_image"

        dataset = data_utils.get_data(dataset_name, split, transform=clip_preprocess)
        class_names = data_utils.get_class_names(dataset_name)

        # Group sample indices by class label; dataset.samples is read as a
        # sequence of (path, label) pairs.
        class_to_indices = {}
        for idx, (_, label) in enumerate(dataset.samples):
            class_to_indices.setdefault(label, []).append(idx)

        for label, indices in class_to_indices.items():
            if len(indices) < NUM_FS:
                print(f"Skip {class_names[label]}")
                continue

            # The loader is not shuffled, so feature rows stay aligned with
            # all_paths below.
            subset = Subset(dataset, indices)
            dataloader = DataLoader(subset, batch_size=512, num_workers=8, pin_memory=True)

            all_features = []
            all_paths = [dataset.samples[i][0] for i in indices]

            with torch.no_grad():
                for images, _ in tqdm(dataloader, desc=f"Processing {class_names[label]}"):
                    images = images.to(device)
                    features = clip_model.encode_image(images)
                    # Normalise in float32 so the CPU-side similarity math does
                    # not depend on the model's compute dtype.
                    features = F.normalize(features.float(), dim=1)
                    all_features.append(features.cpu())
            all_features = torch.cat(all_features, dim=0)  # [N, D]

            # Seeded with the first image of the class, then farthest-first.
            selected = select_diverse_indices(all_features, NUM_FS)

            save_dir = os.path.join(output_root, dataset_name, class_names[label])
            os.makedirs(save_dir, exist_ok=True)
            # NOTE(review): files are copied under their basename only, so two
            # selected images with the same file name would overwrite each
            # other, and files left by an earlier run are not removed.
            for idx_sel in selected:
                src_path = all_paths[idx_sel]
                dst_path = os.path.join(save_dir, os.path.basename(src_path))
                shutil.copy(src_path, dst_path)

            print(f"Finish {class_names[label]}")

    # Printed once, after every dataset has been processed.
    print("All Done")

Processing airplane: 100%|██████████| 1/1 [00:03<00:00,  3.03s/it]


Finish airplane


Processing automobile: 100%|██████████| 1/1 [00:01<00:00,  1.90s/it]


Finish automobile


Processing bird: 100%|██████████| 1/1 [00:02<00:00,  2.14s/it]


Finish bird


Processing cat: 100%|██████████| 1/1 [00:02<00:00,  2.03s/it]


Finish cat


Processing deer: 100%|██████████| 1/1 [00:01<00:00,  1.88s/it]


Finish deer


Processing dog: 100%|██████████| 1/1 [00:02<00:00,  2.32s/it]


Finish dog


Processing frog: 100%|██████████| 1/1 [00:02<00:00,  2.26s/it]


Finish frog


Processing horse: 100%|██████████| 1/1 [00:02<00:00,  2.36s/it]


Finish horse


Processing ship: 100%|██████████| 1/1 [00:02<00:00,  2.25s/it]


Finish ship


Processing truck: 100%|██████████| 1/1 [00:02<00:00,  2.24s/it]

Finish truck
All Done



