In [1]:
import pandas as pd
import numpy as np
import pickle
import os
from PIL import Image

In [2]:
DATA_FOLDER = 'data/domainnet/'

In [3]:
domains = [
    'clipart',
    'infograph',
    'painting',
    'real',
    'sketch',
    'quickdraw',
]

min_samples_per_class = 180

In [4]:
## We only select the most represented classes accross all domains

classes = []
for i, domain in enumerate(domains):
    domain_folder = DATA_FOLDER + domain
    if i == 0:
        classes = os.listdir(domain_folder)
    for cla in os.listdir(domain_folder):
        class_folder = domain_folder + '/' + cla
        n_samples_in_class = len(os.listdir(class_folder))
        if cla in classes and n_samples_in_class < min_samples_per_class:
            classes.remove(cla)

classes

['bird',
 'whale',
 'circle',
 'suitcase',
 'squirrel',
 'feather',
 'strawberry',
 'triangle',
 'teapot',
 'sea_turtle',
 'bread',
 'windmill',
 'zebra',
 'submarine',
 'tiger',
 'headphones',
 'shark']

In [5]:
def preprocess_image(image_filename, img_size=(32, 32)):
    with open(image_filename, 'rb') as f:
        with Image.open(f) as img:
            preprocessed_image = img.convert('RGB').resize(img_size)
    return np.array(preprocessed_image)

In [6]:
x_train, x_test, y_train, y_test = {}, {}, {}, {}

for domain in domains:
    domain_folder = DATA_FOLDER + domain
    x_train[domain], x_test[domain], y_train[domain], y_test[domain] = [], [], [], []
    for i, cla in enumerate(classes):
        class_folder = domain_folder + '/' + cla
        for j, image_filename in enumerate(os.listdir(class_folder)):
            if j < min_samples_per_class:
                img = preprocess_image(domain_folder + '/' + cla + '/' + image_filename)
                if j < min_samples_per_class // 3:
                    x_test[domain].append(img)
                    y_test[domain].append(i)
                else:
                    x_train[domain].append(img)
                    y_train[domain].append(i)
    x_train[domain] = np.moveaxis(np.array(x_train[domain]).astype(float), -1, 1)
    x_test[domain] = np.moveaxis(np.array(x_test[domain]).astype(float), -1, 1)
    y_train[domain], y_test[domain] = np.array(y_train[domain]), np.array(y_test[domain])

In [7]:
with open(DATA_FOLDER + 'preprocessed_domainnet.pickle', 'wb') as handle:
    pickle.dump((x_train, x_test, y_train, y_test), handle, protocol=pickle.HIGHEST_PROTOCOL)