In [1]:
import numpy as np
from PIL import Image
import pandas as pd
from sklearn.model_selection import train_test_split
from collections import Counter
from itertools import chain
import os
import sys

root = os.path.abspath("..")
sys.path.append(root)

from config import Config
from common.constants import IMAGE_SHAPE_WITHOUT_CHANNELS, MODALITY_SET
from plant_clef_meta import *

In [2]:
CONFIG = Config(os.path.join(root, "config.json"))
ORGANS = MODALITY_SET

# Loading & Caching Metadata

In [3]:
meta = PlantClefImage.load(CONFIG, from_cache=True)
meta = [img for img in meta if img.content in ORGANS]

PlantClefImage.save(CONFIG, meta, pretty=True)

len(meta)

67812

# Save Fat File
##### Saves it at path `CONFIG.get_plant_clef_fat_file_path()`

In [4]:
save_fat_file = False  # <<<<<<< SET TO TRUE TO SAVE

if save_fat_file:
    fat_file_path = CONFIG.get_plant_clef_fat_file_path()
    os.makedirs(os.path.dirname(fat_file_path), exist_ok=True)

    with open(fat_file_path, "wb") as fat_file:
        count = len(meta)

        for i, m in enumerate(meta):
            path = m.get_image_file_path(CONFIG)

            with Image.open(path) as img:
                if img.mode != "RGB":
                    img = img.convert("RGB")

                img = img.resize(IMAGE_SHAPE_WITHOUT_CHANNELS)
                data = img.getdata()
                data = list(np.reshape(data, -1))
            fat_file.write(bytes(data))

            print(f"\rFinished {i + 1}/{count}", end="")

# Collect Organ Indices and Corresponding Labels

In [4]:
organ_indices = {o: [] for o in ORGANS}
organ_labels = {o: [] for o in ORGANS}

for i, img in enumerate(meta):
    organ = img.content
    if organ not in ORGANS:
        continue

    organ_indices[organ].append(i)
    organ_labels[organ].append(img.class_id)

# Filter Small Classes

In [5]:
filtered_organ_indices = organ_indices.copy()
filtered_organ_labels = organ_labels.copy()

removed_organ_labels = {o: set() for o in ORGANS}

for organ in ORGANS:
    counter = Counter(filtered_organ_labels[organ])

    for label, count in counter.items():
        if count < 10:
            removed_organ_labels[organ].add(label)

for organ in ORGANS:
    indices = filtered_organ_indices[organ]
    labels = filtered_organ_labels[organ]
    removed_labels = removed_organ_labels[organ]

    filtered_organ_indices[organ] = [index for i, index in enumerate(indices) if labels[i] not in removed_labels]
    filtered_organ_labels[organ] = [label for label in labels if label not in removed_labels]


In [6]:
[(k, len(set(x))) for k, x in filtered_organ_labels.items()]

[('Leaf', 449), ('Flower', 862), ('Fruit', 302), ('Stem', 146)]

# Map Labels To [0; N]

In [7]:
def get_class_distribution(meta, included_labels):
    classes = {}

    for img in meta:
        cls = img.class_id
        if cls in included_labels:
            classes[cls] = (classes[cls] + 1) if cls in classes else 1
    
    return np.array(sorted(classes.items(), key=lambda i: i[1], reverse=True))

class_dist = get_class_distribution(meta, set(chain(*filtered_organ_labels.values())))
class_map = {c: i for i, c in enumerate(class_dist[:, 0])}

mapped_filtered_organ_labels = filtered_organ_labels.copy()

for organ in ORGANS:
    labels = mapped_filtered_organ_labels[organ]
    mapped_filtered_organ_labels[organ] = [class_map[label] for label in labels]

len(class_map)

956

# Split the Indices

In [9]:
X_train = {}
X_valid = {}
X_test = {}

y_train = {}
y_valid = {}
y_test = {}

for organ, indices in filtered_organ_indices.items():
    labels = mapped_filtered_organ_labels[organ]

    train_indices, test_indices, train_labels, test_labels = train_test_split(indices, labels, test_size=0.2, shuffle=True, stratify=labels)
    train_indices, valid_indices, train_labels, valid_labels = train_test_split(train_indices, train_labels, test_size=0.2 / 0.8, shuffle=True, stratify=train_labels)

    X_train[organ] = train_indices
    X_valid[organ] = valid_indices
    X_test[organ] = test_indices

    y_train[organ] = train_labels
    y_valid[organ] = valid_labels
    y_test[organ] = test_labels

#### Output split sizes

In [10]:
[len(x) for x in X_train.values()], \
[len(y) for y in y_train.values()], \
[len(x) for x in X_valid.values()], \
[len(y) for y in y_valid.values()], \
[len(x) for x in X_test.values()], \
[len(y) for y in y_test.values()]

([4450, 2769, 21606, 8448],
 [4450, 2769, 21606, 8448],
 [1484, 923, 7202, 2817],
 [1484, 923, 7202, 2817],
 [1484, 924, 7203, 2817],
 [1484, 924, 7203, 2817])

#### Check that data among splits do not intersect

In [11]:
for organ in ORGANS:
    train = set(X_train[organ])
    valid = set(X_valid[organ])
    test = set(X_test[organ])

    print(f"{train.intersection(valid)} {train.intersection(test)} {valid.intersection(test)}")

set() set() set()
set() set() set()
set() set() set()
set() set() set()


#### Output number of classes per modality for each split

In [12]:
[(o, len(set(y))) for o, y in y_train.items()], \
[(o, len(set(y))) for o, y in y_valid.items()], \
[(o, len(set(y))) for o, y in y_test.items()]

([('Fruit', 302), ('Stem', 146), ('Flower', 862), ('Leaf', 449)],
 [('Fruit', 302), ('Stem', 146), ('Flower', 862), ('Leaf', 449)],
 [('Fruit', 302), ('Stem', 146), ('Flower', 862), ('Leaf', 449)])

#### Check that each modality contains the same classes in each split

In [13]:
for organ in ORGANS:
    train = set(y_train[organ])
    valid = set(y_valid[organ])
    test = set(y_test[organ])

    print(f"lengths: {len(train)} {len(valid)} {len(test)}, intersections: {len(train.intersection(valid))} {len(train.intersection(test))} {len(valid.intersection(test))}")

lengths: 302 302 302, intersections: 302 302 302
lengths: 146 146 146, intersections: 146 146 146
lengths: 862 862 862, intersections: 862 862 862
lengths: 449 449 449, intersections: 449 449 449


# Save Unimodal Files

##### Saves them at path `CONFIG.get_unimodal_csv_file_path(split_name, modality)`, for example:
```python
CONFIG.get_unimodal_csv_file_path("train", "Flower")
```

In [14]:
save = False  # <<<<<<< SET TO TRUE TO SAVE


def save_unimodal_split(X, y, organ, split_name):
    df = pd.DataFrame({
        "Image": X[organ],
        "Label": y[organ],
    })

    csv_path = CONFIG.get_unimodal_csv_file_path(split_name, organ)
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)

    df.to_csv(csv_path, index=False)


if save:
    for organ in ORGANS:
        save_unimodal_split(X_train, y_train, organ, "train")
        save_unimodal_split(X_valid, y_valid, organ, "validation")
        save_unimodal_split(X_test, y_test, organ, "test")

# Generate Multimodal Combinations

##### Saves them at path `CONFIG.get_multimodal_csv_file_path(split_name)`, for example:
```python
CONFIG.get_multimodal_csv_file_path("train")
```

In [15]:
generate = False  # <<<<<<< SET TO TRUE TO GENERATE


def generate_multimodal_combinations(split_name):
    paths = {o: CONFIG.get_unimodal_csv_file_path(split_name, o) for o in ORGANS}
    dfs = {o: pd.read_csv(path) for o, path in paths.items()}

    labels = np.unique(list(chain(*[df["Label"] for df in dfs.values()])))

    combinations = {
        **{o: [] for o in ORGANS},
        "Label": [],
    }

    for label in labels:
        organ_values = [(o, df[df["Label"] == label]["Image"]) for o, df in dfs.items()]
        organ_values = [(o, v) for o, v in organ_values if len(v) > 0]

        sorted_organ_values = sorted(organ_values, key=lambda x: len(x[1]), reverse=True)
        assert sorted_organ_values

        n_combinations = len(sorted_organ_values[0][1])
        assert n_combinations > 0

        organ_names = [x[0] for x in sorted_organ_values]
        organ_values = [x[1] for x in sorted_organ_values]

        organ_values = [np.random.permutation(values) for values in organ_values]
        organ_values = [np.resize(values, n_combinations) for values in organ_values]
        organ_values = [iter(values) for values in organ_values]

        organ_names_values = dict(zip(organ_names, organ_values))

        combinations["Label"].extend(np.repeat(label, n_combinations))

        for _ in range(n_combinations):
            for organ in ORGANS:
                if organ in organ_names_values:
                    value = next(organ_names_values[organ])
                    combinations[organ].append(value)
                else:
                    combinations[organ].append(None)

    df = pd.DataFrame(combinations).sample(frac=1)

    csv_path = CONFIG.get_multimodal_csv_file_path(split_name)
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)

    df.to_csv(csv_path, index=False)


if generate:
    generate_multimodal_combinations("train")
    generate_multimodal_combinations("validation")
    generate_multimodal_combinations("test")