<a href="https://colab.research.google.com/github/aruaru0/pytorch-tests/blob/main/MNIST_MetricLearning_FAISS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install umap-learn
!pip install pytorch-metric-learning
!pip install faiss-cpu

Collecting umap-learn
  Downloading umap-learn-0.5.5.tar.gz (90 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.9/90.9 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pynndescent>=0.5 (from umap-learn)
  Downloading pynndescent-0.5.11-py3-none-any.whl (55 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.8/55.8 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: umap-learn
  Building wheel for umap-learn (setup.py) ... [?25l[?25hdone
  Created wheel for umap-learn: filename=umap_learn-0.5.5-py3-none-any.whl size=86832 sha256=21a03066fd1b944ac329878c7f307e32a074e9f501ea948a63b21ec161483000
  Stored in directory: /root/.cache/pip/wheels/3a/70/07/428d2b58660a1a3b431db59b806a10da736612ebbc66c1bcc5
Successfully built umap-learn
Installing collected packages: pynndescent, umap-learn
Successfully installed pynndescent-0.5.11 umap-learn-0.5.5
Colle

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

# Mini-batch size shared by the training and validation loaders.
batch_size = 64

# Both MNIST splits are stored under ./data, downloaded on first use, and
# passed through torchvision's ToTensor transform.
mnist_options = dict(root="data",
                     transform=torchvision.transforms.ToTensor(),
                     download=True)
train_dataset = torchvision.datasets.MNIST(train=True, **mnist_options)
valid_dataset = torchvision.datasets.MNIST(train=False, **mnist_options)

# Only the training loader shuffles; the validation loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 47262312.71it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 40018399.02it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 23384643.52it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 14345277.69it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [3]:
from pytorch_metric_learning import losses

In [4]:
class MyModel(nn.Module):
  """Two-layer fully connected embedding network.

  The input is reshaped to rows of ``input_size * input_size`` values and
  mapped through 1024 hidden units to a 256-dimensional output. Both layers
  are followed by ReLU, so every output component is non-negative.
  """

  def __init__(self, input_size):
    """Create the two linear layers.

    Args:
      input_size: side length of the input; the flattened width is its square.
    """
    super().__init__()
    self.size = input_size * input_size
    self.fc1 = nn.Linear(self.size, 1024)
    self.fc2 = nn.Linear(1024, 256)

  def forward(self, x):
    """Return the 256-dimensional features for ``x``.

    Args:
      x: tensor whose element count is a multiple of ``self.size``.

    Returns:
      Tensor of shape ``(-1, 256)`` after the second ReLU.
    """
    flat = x.view(-1, self.size)
    hidden = F.relu(self.fc1(flat))
    return F.relu(self.fc2(hidden))

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

# Build the network with input_size=28 and move it to the chosen device.
model = MyModel(28)
model = model.to(device)
model

MyModel(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=256, bias=True)
)

In [5]:
# ArcFace loss over 10 classes and 256-dimensional embeddings (matching the
# model's output width), moved to the same device as the model.
# NOTE(review): margin is presumably expressed in degrees by
# pytorch-metric-learning — confirm against the library documentation.
metric = losses.ArcFaceLoss(num_classes=10, embedding_size=256, scale=64, margin=32)
metric = metric.to(device)

# The loss carries its own parameters, so Adam optimises them together with
# the model's weights, each as its own parameter group.
param_groups = [
    {'params': model.parameters()},
    {'params': metric.parameters()},
]
optimizer = torch.optim.Adam(param_groups, lr=0.001)

In [6]:
def do_train(model, device, loader, criterion, optimizer):
  """Run one optimisation pass over ``loader`` and return the mean batch loss.

  Args:
    model: module to train; it is switched to train mode.
    device: device that every batch is moved to.
    loader: iterable of ``(images, labels)`` batches that supports ``len``.
    criterion: callable ``(outputs, labels) -> loss``.
    optimizer: optimizer that is zeroed and stepped once per batch.

  Returns:
    Sum of the per-batch loss values divided by ``len(loader)``.
  """
  model.train()
  running_loss = 0.0
  for batch_images, batch_labels in tqdm(loader, desc="train"):
    batch_images = batch_images.to(device)
    batch_labels = batch_labels.to(device)

    # Clear gradients left over from the previous step before backprop.
    optimizer.zero_grad()
    loss = criterion(model(batch_images), batch_labels)
    loss.backward()
    optimizer.step()

    running_loss += loss.detach().item()

  return running_loss / len(loader)

def do_valid(model, device, loader, criterion):
  """Evaluate ``model`` on ``loader`` without gradients; return the mean batch loss.

  Args:
    model: module to evaluate; it is switched to eval mode.
    device: device that every batch is moved to.
    loader: iterable of ``(images, labels)`` batches that supports ``len``.
    criterion: callable ``(outputs, labels) -> loss``.

  Returns:
    Sum of the per-batch loss values divided by ``len(loader)``.
  """
  model.eval()
  running_loss = 0.0
  with torch.no_grad():
    for batch_images, batch_labels in tqdm(loader, desc="valid"):
      batch_images = batch_images.to(device)
      batch_labels = batch_labels.to(device)
      loss = criterion(model(batch_images), batch_labels)
      running_loss += loss.detach().item()
  return running_loss / len(loader)

In [7]:
# Train for a fixed number of epochs; after each pass over the training set,
# measure the same ArcFace loss on the validation set and report both values.
num_epochs = 10
for epoch in range(num_epochs):
  print(f'[EPOCH {epoch+1}]')
  train_loss = do_train(model=model,
                        device=device,
                        loader=train_loader,
                        criterion=metric,
                        optimizer=optimizer)
  valid_loss = do_valid(model=model,
                        device=device,
                        loader=valid_loader,
                        criterion=metric)
  print(f"--> train loss {train_loss}, valid loss {valid_loss}")

[EPOCH 1]


train:   0%|          | 0/938 [00:00<?, ?it/s]

valid:   0%|          | 0/157 [00:00<?, ?it/s]

--> train loss 5.349368308251823, valid loss 2.510154813427588
[EPOCH 2]


train:   0%|          | 0/938 [00:00<?, ?it/s]

valid:   0%|          | 0/157 [00:00<?, ?it/s]

--> train loss 2.1482210636433305, valid loss 2.1418662109241895
[EPOCH 3]


train:   0%|          | 0/938 [00:00<?, ?it/s]

valid:   0%|          | 0/157 [00:00<?, ?it/s]

--> train loss 1.575575936103869, valid loss 1.8635885642383838
[EPOCH 4]


train:   0%|          | 0/938 [00:00<?, ?it/s]

valid:   0%|          | 0/157 [00:00<?, ?it/s]

--> train loss 1.2874109143940624, valid loss 1.8497358464007339
[EPOCH 5]


train:   0%|          | 0/938 [00:00<?, ?it/s]

valid:   0%|          | 0/157 [00:00<?, ?it/s]

--> train loss 1.0678269722340434, valid loss 1.569766599231717
[EPOCH 6]


train:   0%|          | 0/938 [00:00<?, ?it/s]

valid:   0%|          | 0/157 [00:00<?, ?it/s]

--> train loss 0.9082480346030417, valid loss 1.5596077285156582
[EPOCH 7]


train:   0%|          | 0/938 [00:00<?, ?it/s]

valid:   0%|          | 0/157 [00:00<?, ?it/s]

--> train loss 0.7792837668679832, valid loss 1.7201417499656628
[EPOCH 8]


train:   0%|          | 0/938 [00:00<?, ?it/s]

valid:   0%|          | 0/157 [00:00<?, ?it/s]

--> train loss 0.6899205521978087, valid loss 1.5696976854674562
[EPOCH 9]


train:   0%|          | 0/938 [00:00<?, ?it/s]

valid:   0%|          | 0/157 [00:00<?, ?it/s]

--> train loss 0.6317136644363094, valid loss 1.7117486108903015
[EPOCH 10]


train:   0%|          | 0/938 [00:00<?, ?it/s]

valid:   0%|          | 0/157 [00:00<?, ?it/s]

--> train loss 0.4990174776054393, valid loss 1.50004615240377


In [7]:
# Embed the whole validation set and project the embeddings to 2-D with UMAP.
model.eval()

feature_batches = []
label_batches = []

with torch.no_grad():
  for images, labels in tqdm(valid_loader):
    images = images.to(device)
    outputs = model(images)
    feature_batches.append(outputs.cpu())
    label_batches.append(labels.cpu())

# Concatenate once at the end: calling torch.cat inside the loop would re-copy
# everything gathered so far on every batch.
features = torch.cat(feature_batches)
classes = torch.cat(label_batches)

import umap
# Keep the reducer under its own name so the `umap` module is not shadowed.
umap_reducer = umap.UMAP(n_components=2, random_state=42)
X_umap = umap_reducer.fit_transform(features.numpy())

# Colour each projected point by its label.
plt.scatter(X_umap[:, 0], X_umap[:, 1], c=classes.numpy())
plt.colorbar(label="label")
plt.title("UMAP projection of validation embeddings")
plt.show()

In [8]:
from pytorch_metric_learning.distances import CosineSimilarity

# Build features for the first validation batch.
images, labels = next(iter(valid_loader))

model.eval()
with torch.no_grad():
  images = images.to(device)
  features = model(images)

# Pairwise cosine similarity between all embeddings in the batch.
ret = CosineSimilarity()(features, features).cpu()

# Print every sample's similarity to the first sample and flag those above the
# 0.9 threshold with a dashed prefix. Iterate over the actual batch length so
# the cell keeps working when the loader's batch size is not 64.
print(f"label = {labels[0]}")
for i in range(len(labels)):
  if ret[0][i] > 0.9 :
    print("-"*40, end="")
  print(f"score = {float(ret[0][i]):.3f} label = {int(labels[i])}")

label = 7
----------------------------------------score = 1.000 label = 7
score = 0.032 label = 2
score = 0.000 label = 1
score = 0.000 label = 0
score = 0.002 label = 4
score = 0.000 label = 1
score = 0.002 label = 4
score = 0.000 label = 9
score = 0.041 label = 5
score = 0.001 label = 9
score = 0.019 label = 0
score = 0.009 label = 6
score = 0.002 label = 9
score = 0.017 label = 0
score = 0.049 label = 1
score = 0.026 label = 5
score = 0.000 label = 9
----------------------------------------score = 0.996 label = 7
score = 0.005 label = 3
score = 0.003 label = 4
score = 0.000 label = 9
score = 0.028 label = 6
score = 0.008 label = 6
score = 0.033 label = 5
score = 0.001 label = 4
score = 0.005 label = 0
----------------------------------------score = 0.959 label = 7
score = 0.003 label = 4
score = 0.014 label = 0
score = 0.109 label = 1
score = 0.005 label = 3
score = 0.106 label = 1
score = 0.006 label = 3
score = 0.004 label = 4
----------------------------------------score = 0.983 

# FAISS

In [9]:
import faiss
import random

In [10]:
# Vector width (the model's 256-dimensional output) and number of IVF lists.
dim, nlist = 256, 10
# Product-quantisation settings passed to IndexIVFPQ; per the FAISS API these
# are the number of sub-quantizers and the bits per sub-quantizer code.
m, nbits = 32, 8

# Flat inner-product index that the IVF index uses as its coarse quantizer.
quantizer = faiss.IndexFlatIP(dim)

# Uncompressed alternative, kept for reference:
# index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT)
index = faiss.IndexIVFPQ(quantizer, dim, nlist, m, nbits, faiss.METRIC_INNER_PRODUCT)

In [11]:
# Extract the features and labels of the entire training set.

model.eval()

feature_batches = []
label_batches = []

with torch.no_grad():
  for images, labels in tqdm(train_loader):
    images = images.to(device)
    outputs = model(images)
    # L2-normalise every row so that inner product equals cosine similarity.
    # F.normalize clamps the norm with a small eps, so an all-zero embedding
    # (possible because the model ends in ReLU) stays zero instead of
    # becoming NaN through a division by zero.
    outputs = F.normalize(outputs, p=2, dim=1)
    feature_batches.append(outputs.cpu())
    label_batches.append(labels.cpu())

# Concatenate once at the end: calling torch.cat inside the loop would re-copy
# everything gathered so far on every batch.
features = torch.cat(feature_batches)
classes = torch.cat(label_batches)

  0%|          | 0/938 [00:00<?, ?it/s]

In [12]:
# Train the index on a subset of the features so it can learn their
# distribution for compression.
TRAIN_FRACTION = 0.01  # share of the extracted features used for index.train
sample_rng = np.random.default_rng(42)  # fixed seed: the same subset on every run
num_train_vecs = max(1, int(len(features) * TRAIN_FRACTION))
# Draw distinct row indices in one vectorised call instead of a per-row
# Python coin flip; sorting keeps the rows in their original order.
sample_idx = np.sort(sample_rng.choice(len(features), size=num_train_vecs, replace=False))
# NOTE(review): FAISS may warn that this subset is small for 2**nbits PQ
# centroids; raise TRAIN_FRACTION if that happens — confirm.
train_data = np.ascontiguousarray(features.numpy()[sample_idx], dtype=np.float32)
index.train(train_data)

In [13]:
# Add the vectors 10000 at a time: the index accepts piecemeal additions, so
# more vectors can also be appended later.
batch_size = 10000

# Convert once to the dtypes used for the index (float32 vectors, int64 ids)
# instead of rebuilding Python lists element by element.
all_vecs = np.ascontiguousarray(features.numpy(), dtype=np.float32)
# The class label of each vector is stored as its id, so a search returns the
# labels of the nearest neighbours directly.
all_ids = np.ascontiguousarray(classes.numpy(), dtype=np.int64)

for i in range(0, len(all_ids), batch_size):
    input_vecs = all_vecs[i:i+batch_size]
    input_ids = all_ids[i:i+batch_size]
    index.add_with_ids(input_vecs, input_ids)

# Save the populated index.
faiss.write_index(index, "features.index")

In [14]:
# Run a nearest-neighbour search with the index.

# Take the first validation batch and compute its features.
images, labels = next(iter(valid_loader))
model.eval()
with torch.no_grad():
  images = images.to(device)
  features = model(images)
  # L2-normalise each row; F.normalize clamps the norm with a small eps, so
  # an all-zero embedding does not turn into NaN.
  features = F.normalize(features, p=2, dim=1)

# Load the saved index.
index = faiss.read_index("features.index")

# Fetch the 3 nearest neighbours per query. I holds the ids stored with
# add_with_ids; D holds the matching scores (the index was built with
# METRIC_INNER_PRODUCT, so larger means more similar).
D, I = index.search(features.cpu().numpy(), 3)

for label, idx, dist in zip(labels, I, D):
  print(f"正解={label}, 検索結果(上位3位) = {idx}, 距離={dist}")

正解=7, 検索結果(上位3位) = [7 7 7], 距離=[1.0041672 1.0022324 1.0019482]
正解=2, 検索結果(上位3位) = [2 2 2], 距離=[1.0064819 1.004201  1.0031328]
正解=1, 検索結果(上位3位) = [1 1 1], 距離=[1.010136  1.0074596 1.0072712]
正解=0, 検索結果(上位3位) = [0 0 0], 距離=[1.0066127 1.0063363 1.005796 ]
正解=4, 検索結果(上位3位) = [4 4 4], 距離=[1.0225395 1.010642  1.009774 ]
正解=1, 検索結果(上位3位) = [1 1 1], 距離=[1.0099753 1.0080909 1.0079058]
正解=4, 検索結果(上位3位) = [4 4 4], 距離=[1.0009947 0.999285  0.998502 ]
正解=9, 検索結果(上位3位) = [9 9 9], 距離=[1.0102962 1.0046943 1.0007346]
正解=5, 検索結果(上位3位) = [5 5 5], 距離=[0.6879167  0.66920096 0.66766775]
正解=9, 検索結果(上位3位) = [9 9 9], 距離=[1.0056764 1.0051448 1.0047344]
正解=0, 検索結果(上位3位) = [0 0 0], 距離=[1.0072564 1.006715  1.006556 ]
正解=6, 検索結果(上位3位) = [6 6 6], 距離=[1.001517  1.0005835 1.0005009]
正解=9, 検索結果(上位3位) = [9 9 9], 距離=[1.005662  1.0055563 1.005211 ]
正解=0, 検索結果(上位3位) = [0 0 0], 距離=[1.011124  1.0101511 1.0082476]
正解=1, 検索結果(上位3位) = [1 1 1], 距離=[1.005921  1.0054464 1.0048993]
正解=5, 検索結果(上位3位) = [5 5 5], 距離=[1.0027235 1.0024531 