In [0]:
# Detect whether we are running inside Google Colab: the `google.colab`
# package only exists there. Catch ImportError specifically so unrelated
# failures (e.g. KeyboardInterrupt) are not silently treated as "not Colab".
try:
    from google.colab import drive
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

print('IN_COLAB', IN_COLAB)

# Captured frames live on Google Drive under Colab, otherwise in a local folder.
if IN_COLAB:
    drive.mount('/content/drive')
    ROOT_DIR = '/content/drive/My Drive/archive/Captured'
else:
    ROOT_DIR = 'Captured'

In [0]:
# Authenticate the Colab user and build a gspread client (`gc`) that is used to
# open the label spreadsheets. This cell imports google.colab unconditionally,
# so it only works inside Colab regardless of IN_COLAB.
# NOTE(review): oauth2client is deprecated upstream; confirm it is still
# available in the runtime before relying on this cell.
from google.colab import auth
auth.authenticate_user()
import gspread
from oauth2client.client import GoogleCredentials
gc = gspread.authorize(GoogleCredentials.get_application_default())

In [0]:
# %tensorflow_version 1.x
# %matplotlib notebook
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

import os
import numpy as np
import glob
import time
import pickle

In [0]:
import cv2
from google.colab.patches import cv2_imshow 

In [0]:
# Log the TensorFlow version in use; the `%tensorflow_version` magic above is
# commented out, so this is whatever the runtime provides by default.
print(tf.__version__)

In [0]:
# Mini-batch size intended for model.fit.
BATCH_SIZE = 32

In [0]:
# Frames are resized to a 16:9 canvas at 20 px per aspect unit (320x180).
# Only the bottom third of the rows (IMG_HEIGHT2) is kept as model input.
IMG_WIDTH = 16 * 20
IMG_HEIGHT = 9 * 20
IMG_HEIGHT2 = IMG_HEIGHT // 3

print(IMG_HEIGHT, IMG_HEIGHT2, IMG_WIDTH)

In [0]:
def read_labels_from_drive(anime_name):
    """Fetch the filename -> label mapping for one title from Google Sheets.

    Opens the spreadsheet titled ``anime_name`` via the global gspread client
    ``gc`` and reads every row of its first worksheet. Column 0 is the key and
    column 1 is parsed with ``int``. If a key repeats, the last row wins.

    Returns:
        dict mapping the column-0 string to its int label.
    """
    sheet = gc.open(anime_name).sheet1
    return {row[0]: int(row[1]) for row in sheet.get_all_values()}

In [0]:
def list_filenames_from_drive(anime_name, extensions=('png', 'jpg')):
    """Return the image paths stored under ``ROOT_DIR/anime_name``.

    Files are grouped by extension in the order given by ``extensions``
    (``.png`` first, then ``.jpg`` by default) and sorted by path within each
    group. The result is therefore deterministic across runs and file systems,
    and so is what ``limit`` keeps in get_dataset.

    Args:
        anime_name: Sub-folder of ROOT_DIR holding the captured frames.
        extensions: File extensions (without the dot) to include.

    Returns:
        A list of file paths.
    """
    # Escape the folder part so characters such as '[' or '*' in a title are
    # matched literally instead of being treated as glob wildcards.
    folder = glob.escape(os.path.join(ROOT_DIR, anime_name))
    filenames = []
    for ext in extensions:
        filenames.extend(sorted(glob.glob(os.path.join(folder, '*.' + ext))))
    return filenames

In [0]:
def get_labels_by_filenames(anime_name, filenames):
    """Look up the label of each path in ``filenames``.

    The sheet for ``anime_name`` is keyed by bare file name, so each path is
    reduced to its final component before the lookup. Raises KeyError if a
    file has no row in the sheet.

    Returns:
        A list of int labels, parallel to ``filenames``.
    """
    label_by_name = read_labels_from_drive(anime_name)
    return [label_by_name[os.path.basename(path)] for path in filenames]

In [0]:
def _read_cropped_grayscale(filename):
    """Load ``filename`` as grayscale, resize it and keep the bottom rows.

    Returns:
        An array of shape (IMG_HEIGHT2, IMG_WIDTH, 1).

    Raises:
        OSError: if OpenCV cannot decode the file. cv2.imread returns None
            instead of raising, which would otherwise fail later in resize.
    """
    img = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise OSError('Failed to read image: {}'.format(filename))
    # cv2.resize takes the target size as (width, height).
    img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))
    # Keep only the bottom IMG_HEIGHT2 rows and add a trailing channel axis.
    return img[IMG_HEIGHT - IMG_HEIGHT2:, :, np.newaxis]


def _write_cache(cache_path, images, labels, filenames):
    """Pickle the dataset tuple to ``cache_path``, creating its folder if needed."""
    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    with open(cache_path, 'wb') as f:
        pickle.dump((images, labels, filenames), f)


def get_dataset(anime_name, limit=None, reload_labels=False):
    """Load (and cache) the cropped grayscale frames and labels for a title.

    On first use, the frames under ROOT_DIR/anime_name are read, cropped with
    _read_cropped_grayscale, labelled from the title's Google Sheet, and
    pickled to ROOT_DIR/data. Later calls load that pickle instead.

    Args:
        anime_name: Title used both as the image folder and the sheet name.
        limit: If truthy, only the first ``limit`` files are loaded. Limited
            loads are cached under a separate file name so they are never
            returned for a full (unlimited) request.
        reload_labels: Re-fetch labels from the sheet for data loaded from
            the cache, and rewrite the cache with them.

    Returns:
        (images, labels, filenames): three parallel lists.
    """
    print('Load {}'.format(anime_name))
    start_time = time.time()

    # The unlimited cache keeps its original name so existing caches are
    # still used; limited loads get a distinct name.
    if limit:
        cache_name = '{}.limit{}.cache'.format(anime_name, limit)
    else:
        cache_name = '{}.cache'.format(anime_name)
    cache_path = os.path.join(ROOT_DIR, 'data', cache_name)

    if os.path.exists(cache_path):
        # NOTE: pickle.load can execute arbitrary code; only load cache files
        # this notebook wrote itself.
        with open(cache_path, 'rb') as f:
            images, labels, filenames = pickle.load(f)
        from_cache = True
    else:
        filenames = list_filenames_from_drive(anime_name)
        if limit:
            filenames = filenames[:limit]
        labels = get_labels_by_filenames(anime_name, filenames)
        images = []
        for i, filename in enumerate(filenames):
            if i % 100 == 0:
                print(i)
            images.append(_read_cropped_grayscale(filename))
        _write_cache(cache_path, images, labels, filenames)
        from_cache = False

    # Freshly built data already has current labels; a refresh only matters
    # for data that came from an existing cache.
    if reload_labels and from_cache:
        labels = get_labels_by_filenames(anime_name, filenames)
        _write_cache(cache_path, images, labels, filenames)

    print('Get {} images from {} in {:.2f} s'.format(len(images), anime_name, time.time()-start_time))

    return images, labels, filenames

In [0]:
# Each title is used both as the Google Sheet name holding its labels
# (gc.open in read_labels_from_drive) and as the image sub-folder under
# ROOT_DIR (list_filenames_from_drive). The trailing numbers are presumably
# image counts per title — TODO confirm.
Kaguya_sama2 = '輝夜姬想讓人告白～天才們的戀愛頭腦戰～ 第二季' # 1555
Bookworm2 = '小書痴的下剋上：為了成為圖書管理員不擇手段！第二季' # 167
SakuraWars = '新櫻花大戰 動畫' # 120
Kakushigoto = '隱瞞之事' # 475
HameFura = '轉生成女性向遊戲只有毀滅END的壞人大小姐' # 48
KinmozaPrettyDays = '黃金拼圖 Pretty Days' # 39

In [0]:
# Accumulators for the training set; the following cells extend them per title.
train_images, train_labels = [], []

In [0]:
# Set to True to add Kaguya_sama2 to the training set. It is also evaluated
# with a Tester further down, so leaving it out keeps that evaluation held-out.
INCLUDE_KAGUYA_SAMA2 = False
if INCLUDE_KAGUYA_SAMA2:
    images, labels, _ = get_dataset(
        Kaguya_sama2,
    )
    train_images += images
    train_labels += labels

In [0]:
# Add the SakuraWars frames to the training set.
# NOTE(review): this cell appends to train_images/train_labels, so re-running it
# duplicates the data. Once the np.array conversion cell has run, `+=` no longer
# concatenates lists. Re-run from the accumulator cell instead.
# NOTE(review): SakuraWars is also evaluated with a Tester further down, so that
# score is measured on data the model was trained/validated on.
images, labels, _ = get_dataset(
    SakuraWars,
)
train_images += images
train_labels += labels

In [0]:
# Number of training examples collected so far.
DS_SIZE = len(train_images)
print('DS_SIZE', DS_SIZE)

In [0]:
# Stack the per-image arrays and the labels into ndarrays for Keras.
train_images, train_labels = np.array(train_images), np.array(train_labels)

In [0]:
print(train_images.shape)
print(train_labels.shape)

# Every image must match the model's input shape, and there must be exactly
# one label per image; fail here rather than deep inside model.fit.
assert train_images.shape[1:] == (IMG_HEIGHT2, IMG_WIDTH, 1), train_images.shape
assert len(train_images) == len(train_labels), (len(train_images), len(train_labels))

In [0]:
def show_batch(image_batch, label_batch, count=None):
    """Plot a grid of single-channel images titled with their labels.

    Args:
        image_batch: Sequence of images, each indexed as ``img[:, :, 0]``.
        label_batch: Labels parallel to ``image_batch``; shown via ``str``.
        count: ``(rows, cols)`` of the grid, defaulting to ``(3, 3)``. Only
            the first ``rows * cols`` images are drawn.
    """
    rows, cols = (3, 3) if count is None else (count[0], count[1])
    plt.figure(figsize=(20, rows * 2), facecolor='w')
    n_shown = min(rows * cols, len(image_batch))
    for idx in range(n_shown):
        ax = plt.subplot(rows, cols, idx + 1)
        plt.imshow(image_batch[idx][:, :, 0], cmap='gray')
        plt.title(str(label_batch[idx]))
        plt.axis('off')
        ax.autoscale(enable=True)

In [0]:
# Spot-check a 3x3 grid of training images with their labels.
show_batch(train_images, train_labels)

In [0]:
# (height, width) of each cropped grayscale input the model receives.
print('input_shape', IMG_HEIGHT2, IMG_WIDTH)

In [0]:
# CNN classifier over the cropped grayscale strip. It has three stages of
# two 3x3 convolutions followed by 2x2 max-pooling, then a dense head ending
# in a 2-way softmax (labels are class ids 0/1 for the sparse categorical loss).
model = models.Sequential([
    layers.Conv2D(64, (3, 3), activation='relu', input_shape=(IMG_HEIGHT2, IMG_WIDTH, 1)),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    layers.Conv2D(256, (3, 3), activation='relu'),
    layers.Conv2D(256, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    layers.Conv2D(256, (3, 3), activation='relu'),
    layers.Conv2D(256, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    layers.Flatten(),
    layers.Dense(4096, activation='relu'),
    layers.Dense(256, activation='relu'),
    layers.Dense(2, activation='softmax'),
])

model.summary()

In [0]:
# Labels are integer class ids, hence the sparse categorical loss.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# NOTE: validation_split holds out the *last* 20% of the arrays (Keras picks
# it before shuffling), so the validation set is whichever files come last in
# the loading order, not a random sample.
history = model.fit(
    train_images, train_labels,
    batch_size=BATCH_SIZE,
    epochs=10,
    validation_split=0.2,
)

In [0]:
# Training vs. validation accuracy per epoch.
fig, ax = plt.subplots(facecolor='w')
ax.plot(history.history['accuracy'], label='accuracy')
ax.plot(history.history['val_accuracy'], label='val_accuracy')
ax.set_title('Training and validation accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
# Start the axis at 0.5 (chance level for two balanced classes) unless a curve
# dips lower. In that case extend the axis so no part of either curve is clipped.
lowest = min(history.history['accuracy'] + history.history['val_accuracy'])
ax.set_ylim([min(0.5, lowest), 1])
ax.legend(loc='lower right');

In [0]:
class Tester:
    """Evaluate a trained model on one title and collect its mistakes.

    On construction this:
      - loads ``anime_name`` via get_dataset;
      - runs ``evaluate`` and ``predict``;
      - writes the filenames of misclassified images to
        ``<anime_name>_error.txt`` (UTF-8, one per line) in the current
        working directory;
      - keeps the misclassified images for plotting.

    Attributes:
        test_images: ndarray of all images for the title.
        test_labels: ndarray of the ground-truth labels.
        test_loss, test_acc: the two values returned by ``evaluate``.
        predict: list of predicted class ids (0/1), one per image.
        error_images: images whose prediction differs from the label.
        error_labels: the (wrong) predicted labels for ``error_images``.
    """

    def __init__(self, anime_name, keras_model=None):
        """Load ``anime_name`` and evaluate ``keras_model`` on it.

        Args:
            anime_name: Title passed to get_dataset.
            keras_model: Model to evaluate. Defaults to the notebook's global
                ``model``, looked up at call time.
        """
        if keras_model is None:
            keras_model = model
        test_images, test_labels, filenames = get_dataset(
            anime_name,
        )
        self.test_images = np.array(test_images)
        self.test_labels = np.array(test_labels)

        self.test_loss, self.test_acc = keras_model.evaluate(
            self.test_images, self.test_labels, verbose=2)

        probabilities = keras_model.predict(self.test_images)
        # Two-class softmax output: pick the more likely class (ties -> 0).
        self.predict = [1 if p[1] > p[0] else 0 for p in probabilities]

        self.error_images = []
        self.error_labels = []
        # `with` guarantees the report is closed even if a write fails.
        with open('{}_error.txt'.format(anime_name), 'w', encoding='utf8') as f:
            for image, label, predicted, filename in zip(
                    self.test_images, self.test_labels, self.predict, filenames):
                if predicted != label:
                    f.write('{}\n'.format(filename))
                    self.error_images.append(image)
                    self.error_labels.append(predicted)
        print('error count', len(self.error_images))

    def show_result(self):
        """Plot the first 15 images (5x3 grid) titled with their predictions."""
        show_batch(self.test_images, self.predict, (5, 3))

    def show_error(self):
        """Plot up to 9 misclassified images titled with the predicted label."""
        show_batch(self.error_images, self.error_labels)

In [0]:
# Evaluate on Kaguya_sama2 and plot the first 15 predictions (5x3 grid).
kaguya_sama2 = Tester(Kaguya_sama2)
kaguya_sama2.show_result()

In [0]:
# Show up to 9 misclassified Kaguya_sama2 frames, titled with the predicted label.
kaguya_sama2.show_error()

In [0]:
# Evaluate on Bookworm2 and plot the first 15 predictions (5x3 grid).
bookworm2 = Tester(Bookworm2)
bookworm2.show_result()

In [0]:
# Show up to 9 misclassified Bookworm2 frames, titled with the predicted label.
bookworm2.show_error()

In [0]:
# Evaluate on SakuraWars and plot the first 15 predictions (5x3 grid).
# NOTE(review): SakuraWars is loaded into the training set earlier in this
# notebook, so this is not a held-out score.
sakuraWars = Tester(SakuraWars)
sakuraWars.show_result()

In [0]:
# Show up to 9 misclassified SakuraWars frames, titled with the predicted label.
sakuraWars.show_error()

In [0]:
# Evaluate on Kakushigoto and plot the first 15 predictions (5x3 grid).
kakushigoto = Tester(Kakushigoto)
kakushigoto.show_result()

In [0]:
# Show up to 9 misclassified Kakushigoto frames, titled with the predicted label.
kakushigoto.show_error()

In [0]:
# Evaluate on HameFura and plot the first 15 predictions (5x3 grid).
hameFura = Tester(HameFura)
hameFura.show_result()

In [0]:
# Show up to 9 misclassified HameFura frames, titled with the predicted label.
hameFura.show_error()

In [0]:
# Evaluate on KinmozaPrettyDays and plot the first 15 predictions (5x3 grid).
kinmozaPrettyDays = Tester(KinmozaPrettyDays)
kinmozaPrettyDays.show_result()

In [0]:
# Show up to 9 misclassified KinmozaPrettyDays frames, titled with the predicted label.
kinmozaPrettyDays.show_error()

In [0]:
# Evaluate on the 'Test' title: a sheet and image folder named 'Test', read
# the same way as the titles above. Plot the first 15 predictions.
test = Tester('Test')
test.show_result()