In [None]:
###### Connect to the google drive ########
# Mount Google Drive at /content/drive. This relies on the google.colab
# helper, so the cell is meant to be run inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
###### Randomly select some images of small_mammals and trees ########
import os
import shutil
import random
from torchvision import datasets

# Root folder on the mounted Google Drive; randomly_select_images() joins
# its Drive-side output locations (sample images, label files, zip) onto it.
google_drive_path = '/content/drive/MyDrive/create_forms'

def randomly_select_images(dataset_path, output_path, superclasses, num):
  """Export CIFAR-100 training images of chosen classes plus a random sample.

  Local output (relative to the working directory):
    ``<output_path>/<superclass>/<subclass>/<i>.png`` for every training
    image of each listed subclass, where ``i`` is the image's index in the
    training split.

  Google Drive output (under the module-level ``google_drive_path``):
    * ``<output_path>/<superclass>/<i>.png`` for the images picked by
      ``selected_some_images`` (up to ``num`` per subclass);
    * ``<output_path>/<superclass>.csv``, a tab-separated label file with
      the header ``image_name``, ``class_name``, ``class_num`` and one row
      per sampled image;
    * ``<output_path>/cifar100_SM_T.zip``, an archive of the whole local
      ``output_path`` tree. Note that this is the full per-class export,
      not only the sampled images.

  Args:
    dataset_path: Directory where the CIFAR-100 data is downloaded/stored.
    output_path: Relative directory name used both for the local export
      and, joined onto ``google_drive_path``, for the Drive export.
    superclasses: Mapping of superclass name to an iterable of CIFAR-100
      class names belonging to it.
    num: Number of images to sample from each subclass.

  Raises:
    KeyError: presumably, when a subclass name is missing from the
      dataset's ``class_to_idx`` mapping (the lookup is not guarded).
  """
  # Create local folders
  os.makedirs(output_path, exist_ok=True)
  os.makedirs(dataset_path, exist_ok=True)
  # Download CIFAR100 dataset (training split)
  cifar100_dataset = datasets.CIFAR100(root=dataset_path, train=True, download=True)

  # The Drive-side root does not depend on the superclass, so it is built
  # once here instead of on every loop iteration.
  drive_output_path = os.path.join(google_drive_path, output_path)
  os.makedirs(drive_output_path, exist_ok=True)

  for superclass, classes in superclasses.items():
    # Drive folder receiving the sampled images of this superclass
    drive_superclass_path = os.path.join(drive_output_path, superclass)
    os.makedirs(drive_superclass_path, exist_ok=True)
    # Local folder receiving every image of this superclass
    superclass_folder_path = os.path.join(output_path, superclass)
    os.makedirs(superclass_folder_path, exist_ok=True)

    # Label file for the sampled images; kept open for the whole superclass
    # rather than being re-opened in append mode for each subclass.
    url_path = os.path.join(drive_output_path, superclass + '.csv')
    with open(url_path, 'w', encoding='utf-8') as outfile:
      outfile.write('image_name\tclass_name\tclass_num\n')

      for subclass in classes:
        subclass_idx = cifar100_dataset.class_to_idx[subclass]
        subclass_folder_path = os.path.join(superclass_folder_path, subclass)
        os.makedirs(subclass_folder_path, exist_ok=True)

        # Training-split positions of every image carrying this label
        indices = [i for i, label in enumerate(cifar100_dataset.targets) if label == subclass_idx]

        # Sample from a copy so the helper cannot reorder `indices`
        selected_indices = selected_some_images(list(indices), num)
        selected_lookup = set(selected_indices)

        # Fetch each image once: always save locally, and additionally to
        # Drive when it belongs to the sample.
        for i in indices:
          img, _ = cifar100_dataset[i]
          file_name = f'{i}.png'
          img.save(os.path.join(subclass_folder_path, file_name))
          if i in selected_lookup:
            img.save(os.path.join(drive_superclass_path, file_name))

        # Rows follow the sampled (shuffled) order
        outfile.writelines(
          '{}.png\t{}\t{}\n'.format(i, subclass, subclass_idx) for i in selected_indices
        )

  # Zip the complete local export (all images of the requested classes)
  shutil.make_archive(os.path.join(drive_output_path, 'cifar100_SM_T'), 'zip', output_path)

# Randomly select num items from each class
def selected_some_images(indices, num):
  """Return a reproducible pseudo-random selection of up to ``num`` items.

  A copy of ``indices`` is shuffled with a private generator seeded with
  2024 and the first ``num`` entries of that shuffled copy are returned.
  The selection is the same one ``random.seed(2024)`` followed by
  ``random.shuffle`` would give, but neither the caller's sequence nor the
  module-level ``random`` state is modified.

  Args:
    indices: Sequence of items to choose from; left untouched.
    num: Upper slice bound applied to the shuffled copy.

  Returns:
    A new list holding the first ``num`` items of the shuffled copy (fewer
    when ``indices`` is shorter than ``num``).
  """
  # Dedicated generator: fixed seed for reproducibility without reseeding
  # the global RNG that unrelated code may depend on.
  rng = random.Random(2024)
  shuffled = list(indices)
  rng.shuffle(shuffled)
  return shuffled[:num]

if __name__ == '__main__':
  # How many images to sample from every class
  num = 100
  # Superclass name -> CIFAR-100 class names it groups
  get_superclasses = dict(
    small_mammals=['hamster', 'mouse', 'rabbit', 'shrew', 'squirrel'],
    trees=['maple_tree', 'oak_tree', 'palm_tree', 'pine_tree', 'willow_tree'],
  )
  # Where the CIFAR100 download lives, and where the exported images go
  dataset_path, output_path = 'dataset', 'select_images'

  # Export the small_mammals and trees images and sample num per class
  randomly_select_images(
    dataset_path=dataset_path,
    output_path=output_path,
    superclasses=get_superclasses,
    num=num,
  )