# ArcFace

In [None]:
import os
from pathlib import Path
from argparse import ArgumentParser
import shutil
import datetime

import tqdm
import torch
import numpy as np
from attrdict import AttrDict
from sklearn.manifold import TSNE

from trainers import get_trainer
from dataset import get_dataset
from models import get_model

from utils import setup_logger, read_yaml, increment_path, save_yaml, save_hostname

In [None]:
# Experiment configuration: a "resnet50" backbone trained with the "arcface" loss on MNIST.
# Both configs are written as plain dict literals and then wrapped in AttrDict, which
# allows attribute-style access (e.g. `arcface_config.train.batch_size`).
arcface_config = {
    "model": {
        "name": "resnet50",
        # NOTE(review): presumably a torch.hub repo/tag consumed by get_model — confirm.
        "version": "pytorch/vision:v0.10.0",
        "pretrained": True,
    },
    "loss": {
        "name": "arcface",
        "params": {
            "num_dim": 512,
            # NOTE(review): s / m look like the ArcFace feature scale and angular margin —
            # confirm against the loss implementation.
            "s": 30.0,
            "m": 0.50,
            "easy_margin": False,
            # The next four keys match torch.nn.CrossEntropyLoss keyword arguments;
            # presumably forwarded to it by the loss factory (verify).
            "size_average": None,
            "ignore_index": -100,
            "reduce": None,
            "reduction": "mean",
        },
    },
    "init": {
        "name": "xavier_uniform",
        "params": {"gain": 1.0},
    },
    "optimizer": {
        "name": "sgd",
        "params": {
            "lr": 0.1,
            "momentum": 0.9,
            "weight_decay": 5.0e-4,
        },
    },
    "train": {
        "dset_type": "clf",
        "epoch": 5,
        "batch_size": 200,
        "num_workers": 10,
    },
    "device": "0",
}

# Dataset configuration for MNIST.
# NOTE(review): "RGB": True alongside "num_channel": 1 — confirm how the dataset/model
# helpers reconcile these.
dset_config = {
    "name": "mnist",
    "classes": 10,
    "root": "../storage",
    "download": True,
    "transforms": {
        "resize": 256,
        "RGB": True,
    },
    "target_transform": None,
    "num_channel": 1,
}

arcface_config = AttrDict(arcface_config)
dset_config = AttrDict(dset_config)

In [None]:
# Build the training and validation datasets, the model and the trainer from the two
# configs, then run training. `valid_dset` and `trainer` are reused by later cells.
dset = get_dataset(arcface_config, dset_config, mode="train")
valid_dset = get_dataset(arcface_config, dset_config, mode="valid")
model = get_model(arcface_config, dset_config)
trainer = get_trainer(arcface_config, dset_config)
trainer.train(dataset=dset, valid_dataset=valid_dset, model=model)

In [None]:
# Use the model instance held by the trainer after training.
# NOTE(review): presumably this is the trained model (possibly wrapped or moved to the
# trainer's device) rather than the object originally passed in — confirm in the trainer.
model = trainer.model

In [None]:
# Run the trained model over the validation split. Collect the raw outputs (used as
# embeddings for the t-SNE cell) and labels, and report overall accuracy.
valid_dloader = torch.utils.data.DataLoader(
    valid_dset,
    batch_size=100,
    # Order is irrelevant for accuracy and t-SNE; a fixed order keeps reruns reproducible.
    shuffle=False,
    num_workers=1,
)
device = trainer.device

# inference_mode only disables autograd. eval() is still needed to switch off
# training-time behaviour such as BatchNorm batch statistics / running-stat updates.
model.eval()

num_correct = 0
num_seen = 0
lt_embedding = list()
lt_labels = list()
with torch.inference_mode():
    for x, labels in tqdm.tqdm(valid_dloader):
        x = x.to(device)
        output = model(x).cpu()
        lt_embedding.append(output.numpy())
        lt_labels.append(labels.numpy())
        # NOTE(review): argmax of the model output is treated as the predicted class.
        # If the model returns an ArcFace embedding rather than class logits, this
        # accuracy is not meaningful — confirm what get_model() returns.
        num_correct += (output.argmax(dim=1) == labels).sum().item()
        num_seen += len(labels)

# Overall accuracy = correct / total samples. Summing per-batch accuracies, as before,
# grows with the number of batches instead of staying in [0, 1].
acc = num_correct / num_seen if num_seen else float("nan")
print(acc)

In [None]:
# Add a t-SNE projection of the collected validation outputs.
# Stack the per-batch arrays row-wise into one matrix, flatten the labels to match,
# then project to 2-D for plotting.
array_embedding = np.vstack(lt_embedding).astype(float)
array_labels = np.hstack(lt_labels).astype(int)
X_embedding = TSNE(
    n_components=2,
    learning_rate="auto",
    init="random",
).fit_transform(array_embedding)

In [None]:
# Scatter the 2-D t-SNE points, coloured by their integer class label.
import matplotlib.pyplot as plt

plt.scatter(
    X_embedding[:, 0],
    X_embedding[:, 1],
    c=array_labels,
    cmap="jet",
)
plt.colorbar()  # colour bar maps colours back to label values
plt.show()