# 概要
類似商標画像検索に向けたモデル学習・最近傍探索のサンプルコードです。  
Nishika「[AI×商標：イメージサーチコンペティション](https://www.nishika.com/competitions/22/summary)」のデータで学習し、[Google Material Icons](https://fonts.google.com/icons)で画像検索結果を検証しています。

（Google Colaboratoryで動作確認）

# 事前準備
Nishika「[AI×商標：イメージサーチコンペティション](https://www.nishika.com/competitions/22/summary)」のデータをダウンロードし、任意のディレクトリに保存します。  
zipファイルは解凍しておきます。

```
（保存例）
input
│   train.csv
│
├───apply_images
│   │   ...
│   
└───cite_images
    │   ...
```


# 必要なライブラリのインストール

In [None]:
# Install metric-learning losses, FAISS (CPU build) and timm.
# albumentations is pinned to 0.4.6; the transforms below use argument names such as
# `value=` (PadIfNeeded) and `quality_lower`/`quality_upper` (ImageCompression).
# NOTE(review): confirm those arguments still exist before upgrading albumentations.
!pip install pytorch-metric-learning faiss-cpu timm
!pip install albumentations==0.4.6

In [None]:
# Force-restart the Google Colab runtime so the newly installed albumentations version
# is the one imported (os._exit terminates the current kernel process immediately).
import os
os._exit(00)

# 学習設定

In [None]:
# Training configuration (backbone, optimisation, data loading, ...).
# Every value is read as a class attribute, e.g. `CFG.img_size`.

class CFG:
    # ---- experiment / bookkeeping ------------------------------------------
    exp_name = 'tutorial'       # experiment name; also the model output folder name
    seed = 82                   # random seed
    device = 'cuda'             # torch device for training / inference
    print_freq = 50             # print training progress every N steps
    eval_freq = 100             # score the validation fold every N steps

    # ---- cross validation --------------------------------------------------
    num_folds = 4               # number of train/validation folds
    train_fold = [0, 1, 2, 3]   # folds to actually train (e.g. [0] for a quick run)

    # ---- optimisation ------------------------------------------------------
    epochs = 10                 # number of training epochs
    lr = 1e-3                   # learning rate
    warmup_steps_rate = 0.1     # fraction of the training steps used for LR warm-up
    apex = True                 # mixed precision (torch.cuda.amp) to reduce memory use
    # Number of batches whose loss is accumulated before each optimizer step
    # (emulates a larger batch when it does not fit in memory).
    gradient_accumulation_steps = 1

    # ---- model -------------------------------------------------------------
    backbone = 'resnet34d'      # timm backbone name
    embed_dim = 512             # dimension of the output embedding

    # ---- data --------------------------------------------------------------
    img_size = 224              # model input size (square)
    batch_size = 16             # batch size
    num_workers = 4             # DataLoader worker processes

# ライブラリのImport

In [None]:
import os
import gc
import json
import yaml
import copy
import time
import glob
import math
import shutil
import random
import itertools
import typing as tp
from pathlib import Path
from joblib import Parallel, delayed
from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import cv2
from PIL import Image, ImageOps
from tqdm.notebook import tqdm

from sklearn.model_selection import KFold, GroupKFold

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import Optimizer, AdamW
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingWarmRestarts

from pytorch_metric_learning import losses

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

import faiss

import warnings
warnings.simplefilter('ignore')

# ディレクトリ設定

In [None]:
# Mount Google Drive; INPUT_DIR / OUTPUT_DIR below point into /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Data location and this experiment's output directory (checkpoints, training log).
INPUT_DIR = '/content/drive/MyDrive/Colab_Notebooks/nishika/patent/input/'
OUTPUT_DIR = f'/content/drive/MyDrive/Colab_Notebooks/nishika/patent/{CFG.exp_name}/'

# exist_ok=True replaces the separate exists() check (no check-then-create race).
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 汎用的な関数

In [None]:
# ref : https://www.kaggle.com/code/yasufuminakama/pppm-deberta-v3-large-baseline-w-w-b-train

def get_logger(filename=OUTPUT_DIR+'train'):
    """Return this module's logger, writing bare messages to stdout and ``<filename>.log``.

    Any handlers already attached to the logger (e.g. from an earlier run of this
    cell) are closed and removed first. Otherwise every re-run adds another pair and
    each message would be printed and written several times.

    Args:
        filename: log file path without the ``.log`` suffix
            (defaults to ``OUTPUT_DIR + 'train'``).
    """
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = Formatter("%(message)s")
    stream_handler = StreamHandler()
    stream_handler.setFormatter(formatter)
    file_handler = FileHandler(filename=f"{filename}.log")
    file_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    return logger

LOGGER = get_logger()


def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    ``PYTHONHASHSEED`` is exported as well; it only affects Python processes started
    after this call, because hash randomisation is fixed at interpreter start-up.
    """
    LOGGER.info(f'seed_everything : {seed}')
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    
seed_everything(seed=CFG.seed)

In [None]:
# Nearest-neighbour search

class FaissKNeighbors:
    """Exact cosine-similarity k-NN search backed by a FAISS inner-product index.

    Vectors are L2-normalised (on a private float32 copy) before they are indexed or
    queried, so the inner product returned by the index is the cosine similarity.

    Args:
        k: number of neighbours per query. When fewer than ``k`` vectors were
            indexed, only as many as exist are returned. FAISS would otherwise pad
            the result with index ``-1``, and ``-1`` silently selects the last row
            when used to index NumPy/pandas data.
    """

    def __init__(self, k=20):
        self.index = None
        self.d = None
        self.k = k

    @staticmethod
    def _prepare(X):
        # Private C-contiguous float32 copy: faiss.normalize_L2 works in place and
        # must not modify the caller's array. Also accepts lists / other array-likes.
        return np.array(X, dtype=np.float32, order='C', copy=True)

    def fit(self, X):
        """Build the index from ``X`` of shape (n_samples, d)."""
        X = self._prepare(X)
        self.d = X.shape[1]
        self.index = faiss.IndexFlatIP(self.d)  # inner product == cosine after normalisation
        faiss.normalize_L2(X)
        self.index.add(X)

    def predict(self, X):
        """Return ``(similarities, indices)`` of the nearest indexed vectors.

        ``X`` is reshaped to (-1, d). For a single query both arrays are 1-D;
        otherwise they have shape (n_queries, min(k, number of indexed vectors)).

        Raises:
            RuntimeError: if called before :meth:`fit`.
        """
        if self.index is None:
            raise RuntimeError('FaissKNeighbors.predict() called before fit()')
        X = np.reshape(self._prepare(X), (-1, self.d))
        faiss.normalize_L2(X)
        k = min(self.k, self.index.ntotal)
        distances, indices = self.index.search(X, k=k)
        if X.shape[0] == 1:
            return distances[0], indices[0]
        return distances, indices

# データ読み込み

In [None]:
# Training table (the notebook uses its gid, cite_gid, path and cite_path columns).
df_train = pd.read_csv(f"{INPUT_DIR}train.csv")

In [None]:
# Quick look at the training table: its shape, then the first rows
# (the trailing expression is rendered as the cell output).
print(df_train.shape)
df_train.head()

# 前処理

In [None]:
# Group images whose annotations overlap: each (gid, cite_gid) row is an edge, and all
# ids connected through any chain of rows share one integer label (used for CV folds).
# NOTE(review): like the previous version, gid and cite_gid values share one id space.

def _connected_component_labels(gids, cite_gids):
    """Map every id in ``gids``/``cite_gids`` to the label of its connected component.

    Each ``(gids[i], cite_gids[i])`` pair links two ids (union-find). Labels are dense
    integers ``0..n_groups-1`` in order of each group's first appearance.

    The greedy version this replaces treated label ``0`` as "no label" (``if label:``),
    and it never merged two existing groups when a row linked them. As a result,
    transitively linked images could receive different labels and land in different folds.
    """
    parent = {}

    def find(x):
        parent.setdefault(x, x)
        root = x
        while parent[root] != root:
            root = parent[root]
        # Path compression: point every visited node directly at the root.
        while parent[x] != root:
            nxt = parent[x]
            parent[x] = root
            x = nxt
        return root

    for a, b in zip(gids, cite_gids):
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_b] = root_a

    root_to_label = {}
    id_to_label = {}
    for node in parent:  # dicts keep insertion (= first appearance) order
        id_to_label[node] = root_to_label.setdefault(find(node), len(root_to_label))
    return id_to_label


dic_id2label = _connected_component_labels(df_train['gid'].values, df_train['cite_gid'].values)
df_train['label'] = df_train['gid'].map(dic_id2label)

In [None]:
# Trim white margins around the mark
# ref : https://www.nishika.com/competitions/22/topics/169

def image_trimming(path):
    """Load the image at ``path`` and crop it to the bounding box of its non-white content.

    The grayscale image is binarised with threshold 240 (THRESH_BINARY: pixels > 240
    become 255). The union of the bounding boxes of all detected contours except the
    first one is kept.
    NOTE(review): skipping contour 0 assumes it is the outer frame of the white
    background; confirm for images whose content touches the border.
    If the crop is tiny (shorter side < 50 px AND area < 10% of the original),
    the original image is returned instead.

    Returns:
        BGR ``np.ndarray`` as produced by ``cv2.imread``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports failure by returning None; fail here with the path
        # instead of with an opaque error inside cvtColor.
        raise FileNotFoundError(f'cv2.imread could not read image: {path}')
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    binary = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY)[1]
    # findContours returns (contours, hierarchy) in OpenCV 4 but
    # (image, contours, hierarchy) in OpenCV 3; index -2 is the contours in both.
    contours = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    boxes = [cv2.boundingRect(contour) for contour in contours[1:]]
    if not boxes:
        return img

    x1_min = min(x for x, _, _, _ in boxes)
    y1_min = min(y for _, y, _, _ in boxes)
    x2_max = max(x + w for x, _, w, _ in boxes)
    y2_max = max(y + h for _, y, _, h in boxes)
    crop_img = img[y1_min:y2_max, x1_min:x2_max]

    # Do not crop when the result would be too small.
    crop_h, crop_w = crop_img.shape[:2]
    if min(crop_h, crop_w) < 50 and crop_h * crop_w < img.shape[0] * img.shape[1] * 0.1:
        return img
    return crop_img

In [None]:
def trimming_and_save(input):
    """Trim one image with ``image_trimming`` and write the result.

    Args:
        input: ``(input_path, output_path, output_dir)`` triple. The name shadows
            the builtin but is kept so existing callers are unaffected.

    Raises:
        OSError: if OpenCV fails to write ``output_path``.
    """
    input_path, output_path, output_dir = input
    crop_img = image_trimming(input_path)
    # This runs in parallel joblib workers: a separate exists()/makedirs() pair races
    # when two workers create the same directory, so rely on exist_ok instead.
    os.makedirs(output_dir, exist_ok=True)
    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(output_path, crop_img):
        raise OSError(f'cv2.imwrite failed for: {output_path}')

In [None]:
## Trim the margins of every image and save it under INPUT_DIR/trim_<dir_name>/

def preprocess(dir_name, img_paths):
    """Trim all ``img_paths`` (relative to ``INPUT_DIR/<dir_name>/``) in parallel.

    Each result is written to ``INPUT_DIR/trim_<dir_name>/<same relative path>``,
    which is where the Dataset classes read it from. The output path used to be
    rebuilt from the last two components of the input path. That matched only when
    every relative path had exactly one sub-directory.
    Duplicate paths (the same cited image appears in several rows) are processed
    once, so two workers never write the same file concurrently.
    """
    base_input_path = INPUT_DIR + f'{dir_name}/'
    base_output_path = INPUT_DIR + f'trim_{dir_name}/'
    inputs = []
    for rel_path in tqdm(list(dict.fromkeys(img_paths))):
        output_path = base_output_path + rel_path
        inputs.append([base_input_path + rel_path, output_path, os.path.dirname(output_path)])
    Parallel(n_jobs=-1)(delayed(trimming_and_save)(args) for args in tqdm(inputs))  # parallel processing

In [None]:
def show_img_pair(img1, img2, *, titles=('Original', 'Trimming')):
    """Show two images side by side without axis ticks.

    Args:
        img1, img2: anything ``plt.imshow`` accepts (PIL image, HWC array or tensor).
        titles: titles of the left and right panel (keyword-only; the default keeps
            the previous labels, with the 'Triming' typo fixed).
    """
    for position, (image, title) in enumerate(zip((img1, img2), titles), start=1):
        plt.subplot(1, 2, position)
        plt.imshow(image)
        plt.title(title)
        plt.xticks([])
        plt.yticks([])
    plt.tight_layout()
    plt.show()

In [None]:
# Trim the applied-for images and the cited images.
for dir_name, path_col in (('apply_images', 'path'), ('cite_images', 'cite_path')):
    preprocess(dir_name, df_train[path_col].values)

In [None]:
# Visually check the margin trimming on the first five apply images.
# NOTE: `ori_img` from the last iteration is reused by the augmentation preview cells below.
for img_path in df_train['path'][:5]:
    ori_img = Image.open(INPUT_DIR + f'apply_images/{img_path}').convert('RGB')
    trim_img = Image.open(INPUT_DIR + f'trim_apply_images/{img_path}').convert('RGB')
    show_img_pair(ori_img, trim_img)

# 交差検証

In [None]:
# ref : https://github.com/ghmagazine/kagglebook/blob/3d8509d1c1b41a765e3f4744ba1fb226188e2b15/ch05/ch05-01-validation.py#L143
# Split the rows into train / validation folds by group (`label`). KFold runs over the
# unique labels and every row inherits its label's fold, so one group never spans two folds.
# sklearn's GroupKFold yields the same split regardless of any seed, which is why KFold
# (shuffle=True, random_state=CFG.seed) is applied to the labels instead.

unique_labels = df_train['label'].unique()
df_train['fold'] = -1

Fold = KFold(n_splits=CFG.num_folds, shuffle=True, random_state=CFG.seed)
for n, (tr_group_idx, val_group_idx) in enumerate(Fold.split(unique_labels)):
    tr_groups, val_groups = unique_labels[tr_group_idx], unique_labels[val_group_idx]  # only val_groups is used below
    df_train.loc[df_train['label'].isin(val_groups), 'fold'] = int(n)
df_train['fold'] = df_train['fold'].astype(int)
display(df_train.groupby('fold').size())  # rows per fold

# Augmentation、画像整形 

In [None]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # RGB
IMAGENET_STD = [0.229, 0.224, 0.225]  # RGB


def _resize_and_pad():
    """Shape-normalisation steps shared by both pipelines.

    Fit the longest side to ``CFG.img_size`` (interpolation=2, i.e. cv2.INTER_CUBIC),
    then pad with white up to a ``CFG.img_size`` square.
    """
    return [
        A.LongestMaxSize(p=1.0, max_size=CFG.img_size, interpolation=2),
        A.PadIfNeeded(p=1.0, min_height=CFG.img_size, min_width=CFG.img_size,
                      border_mode=0, value=[255, 255, 255]),
    ]


def _to_model_input():
    """Final steps shared by both pipelines: ImageNet normalisation, then HWC -> CHW tensor."""
    return [A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ToTensorV2()]


def train_transform():
    """Training pipeline: resize/pad, random augmentations, normalise, to tensor."""
    steps = _resize_and_pad()
    steps += [
        A.RandomResizedCrop(p=0.5, height=CFG.img_size, width=CFG.img_size, scale=[0.7, 1.0]),
        A.ImageCompression(p=0.5, quality_lower=1, quality_upper=50),
        A.RGBShift(p=0.5),
        A.ToGray(p=0.5),
        A.HorizontalFlip(p=0.5),
    ]
    steps += _to_model_input()
    return A.Compose(steps)


def test_transform():
    """Deterministic inference pipeline: resize/pad, normalise, to tensor."""
    return A.Compose(_resize_and_pad() + _to_model_input())

In [None]:
# Preview ten random training augmentations of `ori_img` (left over from the trimming
# check loop above). The tensors are ImageNet-normalised, so imshow clips their colours.
tmp = train_transform()
for _ in range(10):
    aug_img = tmp(image=np.array(ori_img))['image']
    show_img_pair(ori_img, aug_img.permute(1,2,0))  # CHW -> HWC for imshow

In [None]:
# Same preview with the deterministic test pipeline (resize/pad + normalise only),
# so all ten panels are identical.
tmp = test_transform()
for _ in range(10):
    aug_img = tmp(image=np.array(ori_img))['image']
    show_img_pair(ori_img, aug_img.permute(1,2,0))  # CHW -> HWC for imshow

# Dataset

In [None]:
class TrademarkDataset(Dataset):
    """Pairs of (applied-for image, cited image) that share one group label.

    Each item is ``{'images': [...], 'labels': label}``. In evaluation mode the list
    holds the transformed apply image and cited image. In training mode a second,
    independently augmented view of each is appended (apply, cite, apply, cite).
    Images are read from the trimmed copies under ``INPUT_DIR/trim_*_images/``.
    """

    def __init__(self, df, train_mode=False):
        self.path = df["path"].values
        self.cite_path = df["cite_path"].values
        self.label = df["label"].values
        self.transform = train_transform() if train_mode else test_transform()
        self.train_mode = train_mode

    def __len__(self):
        return len(self.path)

    def __getitem__(self, idx):
        apply_img = self._read_image(INPUT_DIR + f'trim_apply_images/{self.path[idx]}')
        cite_img = self._read_image(INPUT_DIR + f'trim_cite_images/{self.cite_path[idx]}')
        n_views = 2 if self.train_mode else 1
        views = []
        for _ in range(n_views):
            views.append(self.transform(image=apply_img)['image'])
            views.append(self.transform(image=cite_img)['image'])
        return {'images': views, 'labels': self.label[idx]}

    def _read_image(self, path):
        """Read ``path`` as an RGB uint8 array."""
        return np.array(Image.open(path).convert('RGB'))

In [None]:
class TrademarkDatasetForInference(Dataset):
    """Single-image dataset used to compute embeddings.

    ``path_col`` names the dataframe column holding relative image paths. Columns
    whose name contains ``'cite'`` are read from ``trim_cite_images``; all others
    are read from ``trim_apply_images``. Items are ``{'images': tensor}``.
    """

    def __init__(self, df, path_col):
        self.path = df[path_col].values
        self.transform = test_transform()
        self.dir_name = 'cite_images' if 'cite' in path_col else 'apply_images'

    def __len__(self):
        return len(self.path)

    def __getitem__(self, idx):
        full_path = INPUT_DIR + f'trim_{self.dir_name}/{self.path[idx]}'
        image = self._read_image(full_path)
        return {'images': self.transform(image=image)['image']}

    def _read_image(self, path):
        """Read ``path`` as an RGB uint8 array."""
        rgb = Image.open(path).convert('RGB')
        return np.array(rgb)

# Model

In [None]:
# ref : https://github.com/ChristofHenkel/kaggle-landmark-2021-1st-place/blob/main/models/ch_mdl_dolg_efficientnet.py#L162

def gem(x, p=3, eps=1e-6):
    """Generalized-mean pooling over the last two (spatial) dimensions.

    Computes ``mean(clamp(x, min=eps) ** p) ** (1 / p)``; the pooled spatial dims are
    kept with size 1, e.g. (N, C, H, W) -> (N, C, 1, 1).
    """
    return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)


class GeM(nn.Module):
    """Module wrapper around :func:`gem` with an optionally learnable exponent ``p``."""

    def __init__(self, p=3, eps=1e-6, p_trainable=False):
        super(GeM, self).__init__()
        if p_trainable:
            # nn.Parameter: a bare `Parameter` name is not imported in this notebook.
            self.p = nn.Parameter(torch.ones(1) * p)
        else:
            self.p = p
        self.eps = eps

    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        # self.p is a plain number unless p_trainable=True, so `.data` cannot be assumed
        # (it raised AttributeError for the default configuration, breaking print(model)).
        p = float(self.p.detach()) if torch.is_tensor(self.p) else float(self.p)
        return f'{self.__class__.__name__}(p={p:.4f}, eps={self.eps})'

In [None]:
# ref：https://www.kaggle.com/code/ttahara/rerun-seti-e-t-resnet18d-baseline?scriptVersionId=68360235&cellId=17
# ref：https://github.com/lyakaap/ISC21-Descriptor-Track-1st/blob/6e3a51be54e1aae8f41a1703bc75da7c143b5c53/exp/v83.py#L103
# ref : https://github.com/ChristofHenkel/kaggle-landmark-2021-1st-place/blob/main/models/ch_mdl_dolg_efficientnet.py

class ImageEmbeddingModel(nn.Module):
    """timm backbone -> GeM pooling -> Linear/BatchNorm/PReLU head -> L2-normalised embedding.

    Args:
        pretrained: load pretrained backbone weights (default True, as before). Pass
            False when a fine-tuned checkpoint is loaded right after construction,
            so weights that would be overwritten immediately are not loaded.
        backbone: timm model name; defaults to ``CFG.backbone``.
        embed_dim: size of the output embedding; defaults to ``CFG.embed_dim``.

    Raises:
        NotImplementedError: if ``backbone`` is not an attribute of ``timm.models``.
    """

    def __init__(self, pretrained=True, backbone=None, embed_dim=None):
        super().__init__()
        backbone = CFG.backbone if backbone is None else backbone
        embed_dim = CFG.embed_dim if embed_dim is None else embed_dim

        if not hasattr(timm.models, backbone):
            raise NotImplementedError(f'unknown timm backbone: {backbone!r}')
        base_model = timm.create_model(
            backbone,
            num_classes=0,        # no classifier head
            pretrained=pretrained,
            in_chans=3,
            global_pool="",       # keep the spatial feature map; GeM pools it below
            drop_path_rate=0.0)

        self.backbone = base_model
        self.global_pool = GeM(p=3, p_trainable=False)
        self.head = nn.Sequential(
                nn.Linear(base_model.num_features, embed_dim, bias=False),
                nn.BatchNorm1d(embed_dim),
                torch.nn.PReLU()
            )

    def forward(self, x):
        """Map an image batch to unit-norm embeddings."""
        x = self.backbone(x)
        x = self.global_pool(x)
        x = x[:, :, 0, 0]  # (N, C, 1, 1) -> (N, C)
        x = self.head(x)
        return F.normalize(x)

# 学習用関数

In [None]:
# ref : https://www.kaggle.com/code/yasufuminakama/pppm-deberta-v3-large-baseline-w-w-b-train?scriptVersionId=90923371&cellId=31

# ====================================================
# Helper functions
# ====================================================
class AverageMeter(object):
    """Track the latest value and the running, count-weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def asMinutes(s):
    """Format a duration in seconds as ``'<minutes>m <seconds>s'`` (seconds truncated)."""
    minutes = math.floor(s / 60)
    seconds = s - 60 * minutes
    return f'{minutes}m {int(seconds)}s'


def timeSince(since, percent):
    """Return the time elapsed since ``since`` and the estimated remaining time.

    ``percent`` is the completed fraction of the work (must be non-zero). The total
    duration is extrapolated as ``elapsed / percent``.
    """
    elapsed = time.time() - since
    remaining = elapsed / (percent) - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))

In [None]:
# ref : https://www.kaggle.com/code/yasufuminakama/pppm-deberta-v3-large-baseline-w-w-b-train?scriptVersionId=90923371&cellId=31

def train_w_eval_fn(fold, train_loader, valid_loader, model, criterion, optimizer, epoch, scheduler, device, best_score):
    """Train ``model`` for one epoch, scoring the validation fold periodically.

    Each batch holds a list of image views per row (apply/cite, two augmentations
    each in training mode). The views are concatenated along the batch axis and the
    labels are tiled once per view, so the metric-learning loss treats every view as
    a sample. ``eval_fn`` runs every ``CFG.eval_freq`` steps (step > 0) and on the
    last step.

    Returns:
        The best validation score so far.
    """
    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.apex)
    # Not named `losses`: that is the pytorch_metric_learning module used in train_loop.
    loss_meter = AverageMeter()
    start = time.time()
    accum_steps = CFG.gradient_accumulation_steps
    last_step = len(train_loader) - 1
    for step, batch in enumerate(train_loader):
        views = batch['images']
        images = torch.cat(list(views), dim=0).to(device)
        labels = torch.tile(batch['labels'], dims=(len(views),)).to(device)
        batch_size = labels.size(0)
        with torch.cuda.amp.autocast(enabled=CFG.apex):
            embeddings = model(images)
        loss = criterion(embeddings, labels)
        # Log the unscaled loss; only the backward pass is divided for accumulation.
        loss_meter.update(loss.item(), batch_size)
        if accum_steps > 1:
            loss = loss / accum_steps
        scaler.scale(loss).backward()
        if (step + 1) % accum_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()
        if step % CFG.print_freq == 0 or step == last_step:
            print('Epoch: [{0}][{1}/{2}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  'LR: {lr:.8f}  '
                  .format(epoch+1, step, len(train_loader),
                          remain=timeSince(start, float(step+1)/len(train_loader)),
                          loss=loss_meter,
                          # get_last_lr() is the supported accessor; get_lr() is deprecated.
                          lr=scheduler.get_last_lr()[0]))
        if step > 0 and (step % CFG.eval_freq == 0 or step == last_step):
            best_score = eval_fn(fold, valid_loader, model, device, best_score)
            model.train()
    elapsed = time.time() - start
    LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {loss_meter.avg:.4f}  time: {elapsed:.0f}s')
    return best_score


def eval_fn(fold, valid_loader, model, device, best_score):
    """Score ``model`` on validation fold ``fold`` and checkpoint it on improvement.

    ``valid_loader`` is accepted for interface compatibility but unused: ``get_score``
    builds its own loaders. An improved model is saved to
    ``OUTPUT_DIR + 'fold{fold}_best.pth'`` as ``{'model': state_dict}``.

    Returns:
        The updated best score.
    """
    print('Eval Start')
    model.eval()
    t_begin = time.time()
    print('Make Embedding')
    score = get_score(model, fold, device)
    LOGGER.info(f'Score: {score:.4f}  time: {time.time() - t_begin:.0f}s')
    if score > best_score:
        LOGGER.info(f'***** Save Best {score:.4f} Model *****')
        checkpoint_path = OUTPUT_DIR+f"fold{fold}_best.pth"
        torch.save({'model': model.state_dict()}, checkpoint_path)
        return score
    LOGGER.info(f'Score: {score:.4f}. No improve from {best_score:.4f} Model')
    return best_score


# Check whether validation queries retrieve their annotated cited image.
def get_score(model, fold, device):
    """Return the top-k retrieval recall of validation fold ``fold``.

    Every cited image listed in ``df_train`` is embedded and indexed. Each validation
    apply image is a query, and it counts as a hit when its own ``cite_gid`` is among
    the ``cite_gid`` values of its nearest neighbours (``FaissKNeighbors`` default k).
    Returns hits / number of validation rows, or 0.0 for an empty fold.
    """
    valid_folds = df_train[df_train['fold'] == fold].reset_index(drop=True)
    if len(valid_folds) == 0:
        return 0.0
    cite_dataset = TrademarkDatasetForInference(df_train, 'cite_path')
    query_dataset = TrademarkDatasetForInference(valid_folds, 'path')
    cite_loader = DataLoader(cite_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
    query_loader = DataLoader(query_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
    cite_embeddings = inference_fn(cite_loader, model, device)
    query_embeddings = inference_fn(query_loader, model, device)

    kn = FaissKNeighbors()
    kn.fit(cite_embeddings)
    _, idxs = kn.predict(query_embeddings)
    # predict() squeezes the result for a single query; restore the (n_queries, k) layout.
    idxs = np.atleast_2d(idxs)

    cite_gids = df_train['cite_gid'].to_numpy()
    hits = 0
    for target, neighbours in zip(valid_folds['cite_gid'].to_numpy(), idxs):
        # FAISS pads with -1 when fewer than k vectors are indexed; -1 would alias the last row.
        neighbours = neighbours[neighbours >= 0]
        if target in cite_gids[neighbours]:
            hits += 1
    return hits / len(valid_folds)
    
    
def inference_fn(data_loader, model, device):
    """Run ``model`` (eval mode, no gradients) over every batch of ``data_loader``.

    Batches are dicts with an ``'images'`` tensor.

    Returns:
        ``np.ndarray`` of the concatenated model outputs, in loader order.
    """
    model.eval()
    model.to(device)
    outputs = []
    with torch.no_grad():
        for batch in tqdm(data_loader, total=len(data_loader)):
            outputs.append(model(batch['images'].to(device)).to('cpu').numpy())
    return np.concatenate(outputs)

In [None]:
# Learning-rate scheduler
# ref : https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/optimization.py#L104

def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
    """Linear warm-up from 0 to the optimizer's lr, then cosine decay.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        num_warmup_steps: length of the linear warm-up phase.
        num_training_steps: total number of scheduler steps.
        num_cycles: number of cosine waves after warm-up (0.5 = a single
            half-cosine from the peak down to 0).
        last_epoch: index of the last step when resuming (-1 starts fresh).

    Returns:
        ``torch.optim.lr_scheduler.LambdaLR`` applying the multiplier per step.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def multiplier(step):
        if step >= num_warmup_steps:
            progress = float(step - num_warmup_steps) / decay_span
            cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
            return max(0.0, 0.5 * (1.0 + cosine))
        return float(step) / warmup_span

    return LambdaLR(optimizer, multiplier, last_epoch)

In [None]:
# ====================================================
# train loop
# ====================================================
def train_loop(folds, fold):
    """Train one fold from scratch; the best checkpoint is saved by ``eval_fn``.

    Args:
        folds: full training dataframe with a ``fold`` column.
        fold: index of the fold held out for validation.
    """
    LOGGER.info(f"========== fold: {fold} training ==========")

    device = CFG.device
    # ====================================================
    # loader
    # ====================================================
    train_folds = folds[folds['fold'] != fold].reset_index(drop=True)
    valid_folds = folds[folds['fold'] == fold].reset_index(drop=True)

    train_dataset = TrademarkDataset(train_folds, train_mode=True)
    valid_dataset = TrademarkDataset(valid_folds, train_mode=False)

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)

    # ====================================================
    # model & optimizer
    # ====================================================
    model = ImageEmbeddingModel()
    model.to(device)

    optimizer = AdamW(model.parameters(), lr=CFG.lr)

    # ====================================================
    # scheduler
    # ====================================================
    # The scheduler steps once per optimizer step, i.e. once every
    # `gradient_accumulation_steps` batches of the drop_last train loader. Deriving the
    # length from the loader makes the cosine decay end at the last step (it was based
    # on len(train_folds) / batch_size, which ignores both).
    steps_per_epoch = max(1, len(train_loader) // CFG.gradient_accumulation_steps)
    num_train_steps = steps_per_epoch * CFG.epochs
    num_warmup_steps = int(num_train_steps * CFG.warmup_steps_rate)

    scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                num_warmup_steps=num_warmup_steps,
                                                num_training_steps=num_train_steps)

    # ====================================================
    # loop
    # ====================================================
    # MultiSimilarityLoss wrapped in CrossBatchMemory (memory_size=1024).
    criterion = losses.CrossBatchMemory(losses.MultiSimilarityLoss(),
                                        memory_size=1024,
                                        embedding_size=CFG.embed_dim)
    best_score = 0.
    for epoch in range(CFG.epochs):
        best_score = train_w_eval_fn(fold, train_loader, valid_loader, model, criterion, optimizer, epoch, scheduler, device, best_score)

    # Drop the references first: empty_cache() cannot release memory that is still
    # held by the live model / optimizer state.
    del model, optimizer, scheduler, criterion
    gc.collect()
    torch.cuda.empty_cache()

# 学習

In [None]:
# Train every fold listed in CFG.train_fold, in ascending fold order.
for fold in range(CFG.num_folds):
    if fold not in CFG.train_fold:
        continue
    train_loop(df_train, fold)

# 検証

In [None]:
# Download the Google Material Design icons, used below as a query/gallery set
# for checking the retrieval results.
!git clone https://github.com/google/material-design-icons.git

In [None]:
# Copy the PNGs of every `.../materialicons/48dp/2x` directory of the cloned repo
# into one flat folder.
SAVE_DIR = './material-design-icons_png/'
if not os.path.exists(SAVE_DIR):
    os.makedirs(SAVE_DIR)

file_lists = glob.glob('./material-design-icons/png/*.png')
for fd_path, sb_folder, sb_file in os.walk('./material-design-icons/png/'):
    # Every file's parent directory is fd_path itself, so filter once per directory.
    if not fd_path.endswith('materialicons/48dp/2x'):
        continue
    for fil in sb_file:
        path = fd_path + '/' + fil
        shutil.copy(path, SAVE_DIR + fil)

In [None]:
# Sorted so the icon order, and therefore which icons are visualised below, is
# reproducible: glob's result order depends on the filesystem.
material_fig_paths = sorted(glob.glob('./material-design-icons_png/*.png'))
material_fig_paths[:5]

In [None]:
def get_material_image(path):
    """Load a Material icon PNG as an RGB uint8 array with the icon dark on white.

    For palette (``P``) images, every non-zero palette index is treated as the icon.
    For every other mode the last band is used as the icon mask (the alpha channel
    for ``LA``/``RGBA``). The mask is inverted, matching the white background
    assumed by the trimming and padding code.
    NOTE(review): the ``P`` branch assumes palette index 0 is the transparent
    background; confirm for icons with a different palette layout.
    """
    with Image.open(path) as img:
        if img.mode == 'P':
            # 255 where the palette index is non-zero. The old `img>0*255` parsed as
            # `img > 0` and only worked because a bool array converts to a 0/255 image.
            mask = (np.array(img) > 0).astype(np.uint8) * 255
        else:
            # Last band (alpha for LA/RGBA); atleast_3d lets single-band images
            # through instead of failing on [:, :, -1].
            mask = np.atleast_3d(np.array(img))[:, :, -1]
    rgb = Image.fromarray(mask).convert('RGB')
    return np.array(ImageOps.invert(rgb))

In [None]:
class MaterialDatasetForInference(Dataset):
    """Inference dataset over Material icon files; items are ``{'images': tensor}``."""

    def __init__(self, path):
        self.path = path
        self.transform = test_transform()

    def __len__(self):
        return len(self.path)

    def __getitem__(self, idx):
        transformed = self.transform(image=get_material_image(self.path[idx]))
        return {'images': transformed['image']}

In [None]:
def get_material_embed(model, device, path):
    """Embed every icon file in ``path`` with ``model``; rows follow the order of ``path``."""
    loader = DataLoader(
        MaterialDatasetForInference(path),
        batch_size=CFG.batch_size,
        shuffle=False,
        num_workers=CFG.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    return inference_fn(loader, model, device)

In [None]:
# Extract icon embeddings with every trained fold model.
# The fold models are trained independently, so their embedding spaces are not aligned,
# and averaging vectors across models mixes unrelated coordinates. Instead, the per-model
# embeddings (already L2-normalised by the model) are concatenated. After FaissKNeighbors
# re-normalises them, the inner product equals the mean of the per-model cosine similarities.
material_embeddings = []
for fold in range(CFG.num_folds):
    if fold in CFG.train_fold:
        model = ImageEmbeddingModel()
        # map_location='cpu': load independently of the device the checkpoint was saved
        # from; the model is moved to CFG.device right afterwards.
        checkpoint = torch.load(OUTPUT_DIR+f"fold{fold}_best.pth", map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        model.to(CFG.device)
        material_embeddings.append(get_material_embed(model, CFG.device, material_fig_paths))
material_embeddings = np.concatenate(material_embeddings, axis=1)

In [None]:
# Nearest-neighbour search: index all icon embeddings and query with the same set,
# so each icon's own entry is normally among its neighbours.
kn = FaissKNeighbors(k=20)
kn.fit(material_embeddings)
dist, idxs = kn.predict(material_embeddings)

In [None]:
# Show the retrieval results: the query icon top-left, then its neighbours (titled with
# their cosine similarity) filling the following rows, 5 per row. The grid size follows
# the number of neighbours (k=20 gives the original 5x5 layout).

visualize_num = 10
n_cols = 5
for n, (d, idx) in enumerate(zip(dist[:visualize_num], idxs[:visualize_num])):
    n_rows = 1 + math.ceil(len(idx) / n_cols)
    plt.figure(figsize=(10, 2 * n_rows))
    plt.subplot(n_rows, n_cols, 1)
    img = get_material_image(material_fig_paths[n])
    plt.imshow(img)
    plt.title('Query')
    for i, _idx in enumerate(idx):
        if _idx < 0:
            # FAISS pads with -1 when fewer than k vectors exist; -1 would show the last icon.
            continue
        plt.subplot(n_rows, n_cols, n_cols + 1 + i)
        img = get_material_image(material_fig_paths[_idx])
        plt.imshow(img)
        plt.title(np.round(d[i], 3))
    plt.tight_layout()
    plt.show()