In [1]:
import os
import glob
import shutil
import numpy as np

import utils

In [2]:
data_dir = './data/'
train_dir = os.path.join(data_dir, 'train_sub')

# Each sub-directory of train_dir is treated as one class. Keep directories only,
# so stray files (e.g. .DS_Store) are never mistaken for a class label; sorting
# gives a stable label order regardless of filesystem listing order.
class_labels = sorted(
    entry for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
)
print(class_labels)

['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___healthy', 'Corn___Common_rust', 'Corn___Northern_Leaf_Blight', 'Corn___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___healthy', 'Tomato___Early_blight', 'Tomato___Septoria_leaf_spot', 'Tomato___healthy']


In [3]:
# Read the image list of each class into a dictionary: class label -> list of paths.
# Images may carry either a '.JPG' or '.jpg' extension, so both are globbed. Where
# globbing is case-insensitive (e.g. Windows) both patterns match the same files,
# so the union is de-duplicated. Sorting makes the list order, and therefore which
# images are taken later, independent of the OS directory-listing order.
train_img_dict = {}
for class_label in class_labels:
    class_folder = os.path.join(train_dir, class_label)
    img_list = sorted(set(
        glob.glob(os.path.join(class_folder, '*.JPG'))
        + glob.glob(os.path.join(class_folder, '*.jpg'))
    ))
    train_img_dict[class_label] = img_list
    print(class_label, ':', len(img_list))

Apple___Apple_scab : 504
Apple___Black_rot : 496
Apple___healthy : 1316
Corn___Common_rust : 953
Corn___Northern_Leaf_Blight : 788
Corn___healthy : 929
Grape___Black_rot : 944
Grape___Esca_(Black_Measles) : 1106
Grape___healthy : 338
Tomato___Early_blight : 800
Tomato___Septoria_leaf_spot : 1416
Tomato___healthy : 1272


### Create client data: each crop's three classes go to a pair of clients, and no image is shared between clients.

In [4]:
# Root path for the per-client data ('client_<id>/<class_label>/'). Only the path
# is built here; the directories themselves are created later with os.makedirs.
client_data_dir = os.path.join(data_dir, 'client_data')

In [5]:
# Each crop's three classes are given to one consecutive pair of clients, so every
# client sees a single crop (label-skewed split):
#   clients 0-1: Apple, 2-3: Corn, 4-5: Grape, 6-7: Tomato.
grouped_class_labels_list = [
    ('Apple___Apple_scab', 'Apple___Black_rot', 'Apple___healthy'),
    ('Corn___Common_rust', 'Corn___Northern_Leaf_Blight', 'Corn___healthy'),
    ('Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___healthy'),
    ('Tomato___Early_blight', 'Tomato___Septoria_leaf_spot', 'Tomato___healthy')
]

# Half-open [start, end) client-id ranges, paired positionally with the groups above.
start_end_pairs = [
    (0, 2),
    (2, 4),
    (4, 6),
    (6, 8),
]
# zip() would silently drop unmatched entries; fail loudly instead.
assert len(start_end_pairs) == len(grouped_class_labels_list)

# Per-class sample size is drawn uniformly from [MIN_IMGS_PER_CLASS, MAX_IMGS_PER_CLASS).
MIN_IMGS_PER_CLASS, MAX_IMGS_PER_CLASS = 100, 200
# Fixed seed so the split (sizes and chosen files) is reproducible on re-run.
rng = np.random.default_rng(42)

# Draw from sorted copies so train_img_dict is left untouched: re-running this cell
# starts from the same full pools instead of popping ever further into them.
remaining_imgs = {label: sorted(paths) for label, paths in train_img_dict.items()}

for (start, end), grouped_class_labels in zip(start_end_pairs, grouped_class_labels_list):
    for client_id in range(start, end):
        current_client_data_dir = os.path.join(client_data_dir, 'client_' + str(client_id))
        img_count_dict = {}
        for class_label in grouped_class_labels:
            class_dir = os.path.join(current_client_data_dir, class_label)
            os.makedirs(class_dir, exist_ok=True)
            pool = remaining_imgs[class_label]
            num_img = int(rng.integers(MIN_IMGS_PER_CLASS, MAX_IMGS_PER_CLASS))
            # Small classes (Grape___healthy has 338 images) cannot always satisfy
            # two draws of up to 199; take what is left instead of raising
            # IndexError from pop() on an empty list.
            if num_img > len(pool):
                print('Warning: {} has only {} images left for client_{} (requested {}).'.format(
                    class_label, len(pool), client_id, num_img))
                num_img = len(pool)
            # pop() removes each file from the pool, so no image goes to two clients.
            for _ in range(num_img):
                shutil.copy(pool.pop(), class_dir)
            img_count_dict[class_label] = num_img

        # Record the actual per-class counts alongside the client's data.
        note_file = os.path.join(current_client_data_dir, 'notes.txt')
        txt = ''.join('{:30}: {}\n'.format(label, img_count_dict[label])
                      for label in grouped_class_labels)
        utils.save_notes(note_file, txt)

### Centralize client data: merge every client's images into a single folder per class.

In [6]:
# Destination folder that pools every client's images, one sub-folder per class.
# exist_ok=True keeps the cell safe to re-run.
centralized_client_data_dir = os.path.join(data_dir, 'centralized_client_data')
os.makedirs(centralized_client_data_dir, exist_ok=True)

In [7]:
# Re-derive the sorted class list from the sub-directories of the training
# folder; plain files in that folder are ignored.
sample_dir = train_dir

class_labels = sorted(
    entry for entry in os.listdir(sample_dir)
    if os.path.isdir(os.path.join(sample_dir, entry))
)

print('len(class_labels):', len(class_labels))
print(class_labels)

len(class_labels): 12
['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___healthy', 'Corn___Common_rust', 'Corn___Northern_Leaf_Blight', 'Corn___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___healthy', 'Tomato___Early_blight', 'Tomato___Septoria_leaf_spot', 'Tomato___healthy']


In [8]:
# Derived from the client-id ranges used for the split (highest exclusive end),
# so it cannot drift out of sync with start_end_pairs.
num_clients = max(end for _, end in start_end_pairs)

In [9]:
# Copy every client's images into one folder per class. Files are renamed
# 'c_<client>_<index>.JPG'; the client id in the name keeps files from different
# clients from colliding ({:02d} is a minimum width, so indices >= 100 stay unique).
for class_label in class_labels:
    cen_class_dir = os.path.join(centralized_client_data_dir, class_label)
    os.makedirs(cen_class_dir, exist_ok=True)

    for i in range(num_clients):
        class_dir = os.path.join(client_data_dir, 'client_' + str(i), class_label)
        # A client holds only some classes; globbing a missing folder yields [].
        # De-duplicate (both patterns match the same files where globbing is
        # case-insensitive) and sort so each file gets a stable index across runs.
        img_list = sorted(set(
            glob.glob(os.path.join(class_dir, '*.JPG'))
            + glob.glob(os.path.join(class_dir, '*.jpg'))
        ))
        for idx, img in enumerate(img_list):
            shutil.copy(img, os.path.join(cen_class_dir, 'c_{:02d}_{:02d}.JPG'.format(i, idx)))