In [None]:
import multiprocessing
from pathlib import Path

import numpy as np
import pandas as pd
import tensorflow as tf
from PIL import Image
from sklearn.model_selection import train_test_split
import os
from tqdm.notebook import tqdm
from datetime import datetime

In [None]:
# Expose only the CUDA device with index 1 to this process (GPU choice is hard-coded).
# NOTE(review): this runs after `import tensorflow`; presumably TensorFlow has not
# initialised its GPUs yet at this point, so the setting still applies -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
# --- Paths -------------------------------------------------------------------
# Directory holding the label CSVs and the train/test image folders.
input_dir = Path("data") / "input"

# --- Training hyper-parameters -----------------------------------------------
epochs = 16  # epochs per model
batch_size = 64  # samples per batch for every data loader

# --- Image geometry ----------------------------------------------------------
image_size = 256  # side length used to compute the crop window in read_image
input_size = 224  # side length of the crop that is fed to the network

# --- Label space -------------------------------------------------------------
num_classes = 40  # width of the softmax output / one-hot labels

In [None]:
def build_model(i):
    """Build a softmax classifier on top of one of five Keras application backbones.

    The backbone constructor is picked round-robin: ``i`` modulo the number of
    candidates indexes into a fixed sequence of ResNetV2 / DenseNet constructors.
    The chosen constructor is printed, then instantiated with ``include_top=False``
    (all other arguments left at their defaults).

    Args:
        i: Integer used to select the backbone (any integer works because of the
            modulo).

    Returns:
        A ``tf.keras.Model`` mapping the backbone's input to ``num_classes``
        softmax probabilities (global average pooling -> dropout 0.5 -> dense).
    """
    backbone_factories = (
        tf.keras.applications.resnet_v2.ResNet50V2,
        tf.keras.applications.resnet_v2.ResNet152V2,
        tf.keras.applications.densenet.DenseNet121,
        tf.keras.applications.densenet.DenseNet169,
        tf.keras.applications.densenet.DenseNet201,
    )
    backbone_factory = backbone_factories[i % len(backbone_factories)]
    print(backbone_factory)
    backbone = backbone_factory(include_top=False)

    # Classification head: pool the feature map, regularise, project to classes.
    pooled = tf.keras.layers.GlobalAveragePooling2D()(backbone.output)
    regularised = tf.keras.layers.Dropout(0.5)(pooled)
    probabilities = tf.keras.layers.Dense(num_classes, activation="softmax")(regularised)
    return tf.keras.Model(backbone.input, probabilities)

In [None]:
def loss_fn(y_true, y_pred):
    """Cross-entropy that also gives partial credit to the two neighbouring classes.

    The loss is a weighted sum of three categorical cross-entropies against
    ``y_pred``: 0.8 x the true one-hot target, 0.1 x the target shifted one class
    down, and 0.1 x the target shifted one class up.  ``tf.roll`` wraps around, so
    a mask zeroes the wrapped-around slot in each shifted target.

    Args:
        y_true: One-hot targets with the class axis at ``axis=1``.
        y_pred: Predicted class probabilities, same layout as ``y_true``.

    Returns:
        The combined loss as returned by ``categorical_crossentropy``.
    """
    cross_entropy = tf.keras.losses.categorical_crossentropy

    # Target for "true class - 1": roll left, then drop the last slot, which
    # would otherwise receive the wrapped-around entry from class 0.
    keep_all_but_last = np.array([1] * (num_classes - 1) + [0], dtype="float32")
    lower_neighbour = keep_all_but_last * tf.roll(y_true, shift=-1, axis=1)

    # Target for "true class + 1": roll right, then drop the first slot, which
    # would otherwise receive the wrapped-around entry from the last class.
    keep_all_but_first = np.array([0] + [1] * (num_classes - 1), dtype="float32")
    upper_neighbour = keep_all_but_first * tf.roll(y_true, shift=1, axis=1)

    return (
        0.8 * cross_entropy(y_true, y_pred)
        + 0.1 * cross_entropy(lower_neighbour, y_pred)
        + 0.1 * cross_entropy(upper_neighbour, y_pred)
    )


def score_fn(y_true, y_pred):
    """Metric: 1.0 where the predicted class is within one class of the true class.

    Both inputs are reduced to class indices with ``argmax`` over the last axis;
    the result is a float32 tensor holding 1.0 where the absolute index
    difference is at most 1 and 0.0 elsewhere.
    """
    true_class = tf.math.argmax(y_true, axis=-1, output_type="int32")
    predicted_class = tf.math.argmax(y_pred, axis=-1, output_type="int32")
    class_distance = tf.math.abs(true_class - predicted_class)
    within_one_class = tf.less_equal(class_distance, 1)
    return tf.cast(within_one_class, "float32")

In [None]:
def _random_erase(
    x, p=0.5, s_l=0.02, s_h=0.4, r_1=0.3, r_2=1 / 0.3, v_l=0, v_h=255, pixel_level=True
):
    """With probability ``p``, overwrite one random rectangle of ``x`` in place.

    Random-erasing augmentation, following
    https://github.com/yu4u/cutout-random-erasing

    Args:
        x: Image array of shape (input_size, input_size, 3); modified in place.
        p: Probability of erasing anything at all.
        s_l, s_h: Lower/upper bound of the erased area as a fraction of the image.
        r_1, r_2: Lower/upper bound of the rectangle's height/width ratio.
        v_l, v_h: Lower/upper bound of the fill values.
        pixel_level: If true, every erased pixel gets its own random value;
            otherwise the whole rectangle is filled with one random value.

    Returns:
        ``x`` (the same array object that was passed in).
    """
    if np.random.rand() >= p:
        return x

    # Rejection-sample a rectangle until it lies completely inside the image.
    while True:
        s = np.random.uniform(s_l, s_h) * input_size * input_size
        r = np.random.uniform(r_1, r_2)
        w = int(np.sqrt(s / r))
        h = int(np.sqrt(s * r))
        left = np.random.randint(0, input_size)
        top = np.random.randint(0, input_size)
        if left + w <= input_size and top + h <= input_size:
            break

    if pixel_level:
        c = np.random.uniform(v_l, v_h, (h, w, 3))
    else:
        c = np.random.uniform(v_l, v_h)
    x[top : top + h, left : left + w, :] = c
    return x


def read_image(path, training):
    """Load one image, crop it to ``input_size`` and apply ResNet50 preprocessing.

    Args:
        path: Location of the image file (anything ``PIL.Image.open`` accepts).
        training: If true, use a random crop position and apply random
            horizontal flip, random 90-degree rotation and random erasing.
            If false, take the centre crop and apply no augmentation.

    Returns:
        The preprocessed image as returned by
        ``tf.keras.applications.resnet50.preprocess_input``.
    """
    # The context manager releases the file handle deterministically;
    # convert() yields an in-memory image that stays valid after the file closes.
    with Image.open(path) as source:
        x = source.convert("RGB")

    # Crop window.  The window is computed from image_size, not from the actual
    # image dimensions -- assumes every file is image_size x image_size; TODO confirm.
    margin = image_size - input_size
    if training:
        # Random position
        left = np.random.randint(0, margin + 1)
        upper = np.random.randint(0, margin + 1)
    else:
        # Centre
        left = margin // 2
        upper = margin // 2
    x = x.crop((left, upper, left + input_size, upper + input_size))

    x = np.array(x).astype("float32")
    if training:
        # Horizontal flip
        if np.random.rand() < 0.5:
            x = x[:, ::-1]

        # Rotation by a random multiple of 90 degrees
        x = np.rot90(x, np.random.randint(0, 4))

        # Random erasing
        x = _random_erase(x)

    # NOTE(review): the ResNet50 (v1) preprocessing is applied regardless of which
    # backbone build_model selects (ResNetV2 / DenseNet variants, each of which
    # has its own preprocess_input) -- confirm this is intended.
    x = tf.keras.applications.resnet50.preprocess_input(x)
    return x


# Oversample so that every class ends up with the same number of images.
def oversampling(org_image_path_list, org_labels):
    """Replicate samples of smaller classes until all classes match the largest.

    For each label (in ascending order) the class's paths are repeated in full as
    often as fits, and the remainder is filled with a random subset of that class
    drawn via ``np.random.permutation``.

    Args:
        org_image_path_list: Image paths, indexable by position.
        org_labels: Array of labels aligned with ``org_image_path_list``; its
            ``dtype`` is reused for the returned labels.

    Returns:
        ``(image_path_list, labels)``: a list of paths grouped by ascending
        label, and a 1-D label array in which every label appears exactly as
        many times as the most frequent label of the input.
    """
    counts_by_label = pd.Series(org_labels).value_counts()
    target_count = counts_by_label.max()
    sorted_labels = counts_by_label.sort_index().index.to_list()

    balanced_paths = []
    for current_label in sorted_labels:
        member_positions = [
            pos for pos, label in enumerate(org_labels) if label == current_label
        ]
        class_paths = [org_image_path_list[pos] for pos in member_positions]
        class_size = len(member_positions)
        rounds = int(np.ceil(target_count / class_size))

        if rounds == 1:
            # The class already has the target size; take it as is.
            balanced_paths.extend(class_paths)
            continue

        # (rounds - 1) complete copies, then a random subset to fill the rest.
        balanced_paths.extend(class_paths * (rounds - 1))
        shortfall = target_count - (rounds - 1) * class_size
        balanced_paths.extend(np.random.permutation(class_paths)[:shortfall].tolist())

    balanced_labels = np.array(
        [label for label in sorted_labels for _ in range(target_count)],
        dtype=org_labels.dtype,
    )
    return balanced_paths, balanced_labels


class DataLoader(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that yields ``(images, one_hot_labels)`` batches.

    In training mode the sample list is rebuilt with ``oversampling`` and
    shuffled at construction and again at the end of every epoch, and images are
    read with augmentation.  Otherwise the samples are served once in their
    given order without augmentation.
    """

    def __init__(self, image_path_list, year_list, batch_size, training):
        """Store the raw samples and prepare the first epoch's ordering.

        Args:
            image_path_list: Image paths, indexable by position.
            year_list: One year per image; converted with ``int(year) - 1979``
                to obtain the class id (the original author's note gives the
                mapping [1979, 2018] -> [0, 39]).
            batch_size: Number of samples per batch.
            training: Enables oversampling, shuffling and augmentation.
        """
        self.org_image_path_list = image_path_list
        self.org_labels = np.array(
            [int(year) - 1979 for year in year_list], dtype="int32"
        )
        self.batch_size = batch_size
        self.training = training

        if not self.training:
            # Evaluation: keep the samples exactly as given, in order.
            self.image_path_list = self.org_image_path_list
            self.labels = self.org_labels
            self.indices = np.arange(len(self.image_path_list))
            return

        self._resample()

    def _resample(self):
        """Draw a fresh class-balanced sample list and a new random order."""
        balanced_paths, balanced_labels = oversampling(
            self.org_image_path_list, self.org_labels
        )
        self.image_path_list = balanced_paths
        self.labels = balanced_labels
        self.indices = np.random.permutation(len(self.image_path_list))

    def __len__(self):
        """Number of batches per epoch (the last batch may be smaller)."""
        sample_count = len(self.image_path_list)
        return int(np.ceil(sample_count / self.batch_size))

    def __getitem__(self, i):
        """Return batch ``i`` as ``(images, one_hot_labels)``."""
        start = i * self.batch_size
        batch_indices = self.indices[start : start + self.batch_size]

        batch_images = [
            read_image(self.image_path_list[sample], self.training)
            for sample in batch_indices
        ]
        one_hot = tf.keras.utils.to_categorical(self.labels[batch_indices], num_classes)
        return np.array(batch_images), one_hot

    def on_epoch_end(self):
        """Re-balance and re-shuffle the training samples; no-op for evaluation."""
        if self.training:
            self._resample()

In [None]:
# ---------------------------------------------------------------------------
# Loop-invariant setup.  None of this depends on the loop index and the split
# uses a fixed random_state, so it is done once instead of once per model.
# Variable names match the ones the loop body used to (re)bind.
# ---------------------------------------------------------------------------
log_dir = Path("result")

# np.save does not create missing directories; create the prediction folder up
# front so a missing folder cannot fail the run after a full training round.
(log_dir / "classifier").mkdir(parents=True, exist_ok=True)

# Labelled data -> fixed, stratified 80/20 train/validation split.
df = pd.read_csv(input_dir / "train_labels.csv", names=["name", "label"], header=None)
df["image_path"] = df["name"].apply(lambda x: input_dir / "train_images" / x)
df = df[["image_path", "label"]]

train_df, valid_df = train_test_split(
    df, test_size=0.2, random_state=42, shuffle=True, stratify=df["label"]
)
train_df.to_csv("list_train.csv", header=None, index=None)
valid_df.to_csv("list_valid.csv", header=None, index=None)

# Test data.  The loader requires labels, so a constant placeholder year is used.
df = pd.read_csv(
    input_dir / "sample_submission.csv", names=["name", "label"], header=None
)
image_path_list = df["name"].apply(lambda x: input_dir / "test_images" / x)
dummy_label_list = [1979] * len(image_path_list)

# ---------------------------------------------------------------------------
# Train one model per iteration (backbone chosen by build_model(i)) and store
# its test-set predictions as a timestamped .npy file.
# ---------------------------------------------------------------------------
for i in tqdm(range(1000)):
    train_gen = DataLoader(
        train_df["image_path"].to_list(), train_df["label"].to_list(), batch_size, True
    )
    valid_gen = DataLoader(
        valid_df["image_path"].to_list(), valid_df["label"].to_list(), batch_size, False
    )

    # Every round logs and checkpoints to the same paths, so the weight files of
    # earlier rounds are overwritten; only the saved predictions are kept per round.
    callbacks = [
        tf.keras.callbacks.TensorBoard(log_dir=str(log_dir), profile_batch=0),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=str(log_dir / "weights-epoch{epoch:04}.h5"), save_weights_only=True
        ),
    ]

    model = build_model(i)
    model.compile(
        loss=loss_fn,
        # `learning_rate` is the documented argument name; `lr` is a deprecated alias.
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        metrics=["accuracy", score_fn],
    )
    # NOTE(review): fit_generator / predict_generator are deprecated in later
    # TensorFlow 2.x releases in favour of fit / predict; kept because the
    # TensorFlow version in use is not visible here.
    model.fit_generator(
        train_gen,
        validation_data=valid_gen,
        epochs=epochs,
        callbacks=callbacks,
        workers=multiprocessing.cpu_count(),
        use_multiprocessing=True,
        # The loader shuffles itself (see DataLoader.on_epoch_end).
        shuffle=False,
    )

    # Pass a plain list, like the train/validation loaders, so that positional
    # indexing inside the loader does not depend on a pandas index.
    test_gen = DataLoader(
        image_path_list.to_list(), dummy_label_list, batch_size, training=False
    )

    y_pred = model.predict_generator(test_gen, verbose=1)
    date = int(float(datetime.now().timestamp()))
    np.save("./result/classifier/{0}.npy".format(date), y_pred)

    # Release the model and the Keras graph state before building the next one.
    del model
    tf.keras.backend.clear_session()