In [25]:
import glob
import shutil
import os
import numpy as np
import collections

In [20]:
source_directory = "../.downloads/PACS"
destination_directory = "../datasets/PACS_shared"


def collect_images(source_directory):
	"""Index every file found under ``<source_directory>/<domain>/<object>/``.

	Args:
		source_directory: root folder whose first level is treated as the
			domain and whose second level is treated as the object class.

	Returns:
		dict mapping ``(domain_name, object_name)`` to a list of
		``(source_path, relative_destination_path)`` tuples. The relative
		destination is ``<domain>_<object>/<domain>_<object>_<stem>.jpg``.
	"""
	from_to_list = {}

	# glob returns entries in a filesystem-dependent order; sorting makes the
	# index (and therefore any seeded shuffle applied to it later) reproducible.
	for domain in sorted(glob.glob(f"{source_directory}/*")):
		domain_name = os.path.basename(domain)
		# `object_dir` rather than `object`, so the builtin is not shadowed.
		for object_dir in sorted(glob.glob(f"{domain}/*")):
			object_name = os.path.basename(object_dir)
			new_dir_postfix = f"{domain_name}_{object_name}"
			pairs = from_to_list.setdefault((domain_name, object_name), [])

			# prepare all images
			for image in sorted(glob.glob(f"{object_dir}/*")):
				# splitext drops whatever extension is present (.jpg, .png,
				# .jpeg, or none) instead of blindly cutting four characters.
				stem = os.path.splitext(os.path.basename(image))[0]
				# NOTE: only the *name* gets a .jpg suffix here; nothing in
				# this cell converts the file contents.
				new_name = f"{new_dir_postfix}_{stem}.jpg"
				pairs.append((image, os.path.join(new_dir_postfix, new_name)))

	return from_to_list


from_to_list = collect_images(source_directory)

In [21]:
# Report how many images were indexed for each (domain, object) pair,
# followed by the grand total across all pairs.
total_images = 0
for pair, pair_images in from_to_list.items():
	pair_count = len(pair_images)
	total_images += pair_count
	print(f"{pair} has {pair_count} images")
print(f"in total {total_images} images")

('photo', 'house') has 280 images
('photo', 'horse') has 199 images
('photo', 'guitar') has 186 images
('photo', 'giraffe') has 182 images
('photo', 'person') has 432 images
('photo', 'elephant') has 202 images
('photo', 'dog') has 189 images
('sketch', 'house') has 80 images
('sketch', 'horse') has 816 images
('sketch', 'guitar') has 608 images
('sketch', 'giraffe') has 753 images
('sketch', 'person') has 160 images
('sketch', 'elephant') has 740 images
('sketch', 'dog') has 772 images
('cartoon', 'house') has 288 images
('cartoon', 'horse') has 324 images
('cartoon', 'guitar') has 135 images
('cartoon', 'giraffe') has 346 images
('cartoon', 'person') has 405 images
('cartoon', 'elephant') has 457 images
('cartoon', 'dog') has 389 images
('art_painting', 'house') has 295 images
('art_painting', 'horse') has 201 images
('art_painting', 'guitar') has 184 images
('art_painting', 'giraffe') has 285 images
('art_painting', 'person') has 449 images
('art_painting', 'elephant') has 255 image

In [33]:
def split_and_copy(from_to_list, destination_directory, train_fraction=0.8, seed=42):
	"""Split every class into Train/Test and copy the files into place.

	Args:
		from_to_list: dict mapping a ``(domain, object)`` tuple to a list of
			``(source_path, relative_destination_path)`` tuples.
		destination_directory: root under which ``Train/`` and ``Test/`` are
			created.
		train_fraction: share of each class that goes to Train; the count is
			truncated with ``int()`` and the remainder goes to Test.
		seed: seed for the private random generator used for shuffling.

	Returns:
		``{'Train': defaultdict(int), 'Test': defaultdict(int)}`` counting the
		images assigned to each split per class. Images are counted even when
		the destination file already exists and the copy is skipped.
	"""
	# A private generator replaces np.random.seed(): it draws the same legacy
	# MT19937 stream for a given seed but leaves the global NumPy RNG untouched.
	rng = np.random.RandomState(seed)

	sampled_stat = {
		'Train': collections.defaultdict(int),
		'Test': collections.defaultdict(int),
	}

	for cls, images in from_to_list.items():
		# Shuffle a copy. Shuffling `images` in place would permanently
		# reorder from_to_list, so re-running the cell would produce a
		# different split and (because existing files are skipped, not
		# removed) could leave the same image in both Train and Test.
		shuffled = list(images)
		rng.shuffle(shuffled)
		num_train = int(len(shuffled) * train_fraction)

		splits = {
			'Train': shuffled[:num_train],
			'Test': shuffled[num_train:],
		}

		for split_name, split_images in splits.items():
			# exist_ok=True already tolerates an existing directory.
			os.makedirs(
				os.path.join(destination_directory, split_name, "_".join(cls)),
				exist_ok=True,
			)

			for image, new_image in split_images:
				sampled_stat[split_name][cls] += 1
				new_image_path = os.path.join(destination_directory, split_name, new_image)

				# Never overwrite a file that is already in place.
				if not os.path.exists(new_image_path):
					shutil.copy(image, new_image_path)

	return sampled_stat


sampled_stat = split_and_copy(from_to_list, destination_directory)

In [34]:
# Display the per-split, per-class counts gathered by the split cell above.
# Each count is incremented once per image assigned to the split, whether the
# file was freshly copied or skipped because it already existed.
sampled_stat

{'Train': defaultdict(int,
             {('photo', 'house'): 224,
              ('photo', 'horse'): 159,
              ('photo', 'guitar'): 148,
              ('photo', 'giraffe'): 145,
              ('photo', 'person'): 345,
              ('photo', 'elephant'): 161,
              ('photo', 'dog'): 151,
              ('sketch', 'house'): 64,
              ('sketch', 'horse'): 652,
              ('sketch', 'guitar'): 486,
              ('sketch', 'giraffe'): 602,
              ('sketch', 'person'): 128,
              ('sketch', 'elephant'): 592,
              ('sketch', 'dog'): 617,
              ('cartoon', 'house'): 230,
              ('cartoon', 'horse'): 259,
              ('cartoon', 'guitar'): 108,
              ('cartoon', 'giraffe'): 276,
              ('cartoon', 'person'): 324,
              ('cartoon', 'elephant'): 365,
              ('cartoon', 'dog'): 311,
              ('art_painting', 'house'): 236,
              ('art_painting', 'horse'): 160,
              ('art_paintin

sanity checks

In [35]:
# Sanity check: count the .jpg files found (recursively) under each split
# directory and show both totals as the cell output.
train_pattern = f"{destination_directory}/Train/**/*.jpg"
test_pattern = f"{destination_directory}/Test/**/*.jpg"
train = glob.glob(train_pattern, recursive=True)
test = glob.glob(test_pattern, recursive=True)
len(train), len(test)

(7984, 2007)