In [1]:
import json
import os
import torch
import torchvision

from ingredients.utilities import set_seed
from ingredients.dataset import common_transform

In [2]:
# ----------
# CONFIGS
# ----------

# Seed passed to set_seed(); it is also embedded in the split filename so that
# splits generated with different seeds end up in different files.
SEED = 1233

# Root directory given to the torchvision dataset loader. It can be overridden
# with the ROOT_DATASET environment variable so the notebook runs on other
# machines without editing this cell; the original path stays the default.
ROOT_DATASET = os.environ.get("ROOT_DATASET") or "/home/lazzaro/data/datasets"

# Name of the JSON file that stores the train/test indices.
SPLIT_FILENAME = f"caltech101_train_test_split.seed={SEED}.json"

# Folder in which the split file is written.
SPLIT_SAVE_FOLDER = "./configs/datasets"

# ----------

# Fraction of the dataset assigned to the training split; the remainder is
# used as the test split.
TRAIN_SIZE = 0.8

In [3]:
# Seed the RNGs before the transform and the dataset are built.
set_seed(SEED)

# The destination folder has to exist before the split file is written.
os.makedirs(SPLIT_SAVE_FOLDER, exist_ok=True)

transform = common_transform()

# download=False: no download is attempted, the data is read from ROOT_DATASET.
dataset = torchvision.datasets.Caltech101(
    root=ROOT_DATASET,
    transform=transform,
    download=False,
)

num_samples = len(dataset)
dataset_indices = list(range(num_samples))

# int() truncates, so the train split gets floor(TRAIN_SIZE * N) samples and
# the test split receives everything that is left.
train_size = int(TRAIN_SIZE * num_samples)
test_size = num_samples - train_size

# Seed again right before drawing the permutation.
# NOTE(review): presumably set_seed reseeds torch's global RNG, which would make
# the shuffle independent of any random draws made above; confirm in
# ingredients.utilities.
set_seed(SEED)
shuffled_indices = torch.randperm(num_samples).tolist()
train_indices, test_indices = (
    shuffled_indices[:train_size],
    shuffled_indices[train_size:],
)

# Persist both index lists as a single JSON object with "train" and "test" keys.
save_path = os.path.join(SPLIT_SAVE_FOLDER, SPLIT_FILENAME)
with open(save_path, mode="w") as f:
    f.write(json.dumps({"train": train_indices, "test": test_indices}))

In [5]:
# Display the category names exposed by the Caltech101 dataset object loaded above.
dataset.categories

['Faces',
 'Faces_easy',
 'Leopards',
 'Motorbikes',
 'accordion',
 'airplanes',
 'anchor',
 'ant',
 'barrel',
 'bass',
 'beaver',
 'binocular',
 'bonsai',
 'brain',
 'brontosaurus',
 'buddha',
 'butterfly',
 'camera',
 'cannon',
 'car_side',
 'ceiling_fan',
 'cellphone',
 'chair',
 'chandelier',
 'cougar_body',
 'cougar_face',
 'crab',
 'crayfish',
 'crocodile',
 'crocodile_head',
 'cup',
 'dalmatian',
 'dollar_bill',
 'dolphin',
 'dragonfly',
 'electric_guitar',
 'elephant',
 'emu',
 'euphonium',
 'ewer',
 'ferry',
 'flamingo',
 'flamingo_head',
 'garfield',
 'gerenuk',
 'gramophone',
 'grand_piano',
 'hawksbill',
 'headphone',
 'hedgehog',
 'helicopter',
 'ibis',
 'inline_skate',
 'joshua_tree',
 'kangaroo',
 'ketch',
 'lamp',
 'laptop',
 'llama',
 'lobster',
 'lotus',
 'mandolin',
 'mayfly',
 'menorah',
 'metronome',
 'minaret',
 'nautilus',
 'octopus',
 'okapi',
 'pagoda',
 'panda',
 'pigeon',
 'pizza',
 'platypus',
 'pyramid',
 'revolver',
 'rhino',
 'rooster',
 'saxophone',
 'sc