In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import time
import random
import glob
from PIL import Image
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.multiprocessing.spawn import spawn
import torchvision.models as models
from torchvision.datasets import VOCSegmentation, CIFAR10, MNIST
from torchvision.io import read_image
from torchvision.transforms import v2

In [4]:
class TripletMNIST(Dataset):
    """MNIST wrapper that serves (anchor, positive, negative) image triplets.

    The anchor is the sample at the requested index, the positive is a random
    sample sharing the anchor's label, and the negative is a random sample
    drawn from a randomly chosen different label.

    Args:
        root: Directory where the MNIST files are stored / downloaded.
        train: Use the training split if True, otherwise the test split.
        transform: Optional transform applied to every image.
        download: Download the dataset if it is not present under ``root``.
    """

    def __init__(self, root, train=True, transform=None, download=True):
        self.mnist = MNIST(
            root=root,
            train=train,
            transform=transform,
            download=download)
        # Build the label -> sample-index lookup directly from the label
        # tensor. Enumerating the dataset itself would decode and transform
        # every image just to read its label.
        self.label_to_idxs = defaultdict(list)
        for idx, label in enumerate(self.mnist.targets.tolist()):
            self.label_to_idxs[label].append(idx)

    def __len__(self):
        """Return the number of anchors, i.e. the size of the wrapped dataset."""
        return len(self.mnist)

    def __getitem__(self, index):
        """Return ``((anchor, positive, negative), [])`` for the given index.

        The second element is an empty list used as a placeholder target.
        """
        img_a, label_a = self.mnist[index]  # anchor

        # positive: pick a different index carrying the same label.
        # Rejection sampling only terminates if the class holds another
        # sample, so a single-sample class falls back to the anchor itself
        # instead of looping forever.
        same_label_idxs = self.label_to_idxs[label_a]
        pos_index = index
        if len(same_label_idxs) > 1:
            while pos_index == index:
                pos_index = random.choice(same_label_idxs)
        img_p, _ = self.mnist[pos_index]

        # negative: pick a different label at random, then a sample from it.
        neg_label = random.choice([l for l in self.label_to_idxs.keys() if l != label_a])
        neg_index = random.choice(self.label_to_idxs[neg_label])
        img_n, _ = self.mnist[neg_index]

        return (img_a, img_p, img_n), []
    
# Preprocessing pipeline: PIL image -> float32 tensor in [0, 1] -> standardized.
# v2.ToTensor() is deprecated; ToImage + ToDtype(scale=True) is the
# replacement recommended by the torchvision v2 documentation.
transform = v2.Compose([
    v2.ToImage(),                           # PIL image -> tensor image
    v2.ToDtype(torch.float32, scale=True),  # [0,255] -> [0.,1.]
    v2.Normalize((0.1307,), (0.3081,)),     # MNIST mean / std
])


class EmbeddingNet(nn.Module):
    """Small CNN that maps a 1x28x28 image to a unit-length embedding.

    Args:
        embedding_dim: Size of the output embedding vector.
    """

    def __init__(self, embedding_dim=32):
        super().__init__()

        # Convolutional feature extractor (shapes are for a 1x28x28 input).
        conv_layers = [
            nn.Conv2d(1, 32, kernel_size=5),   # -> 32 x 24 x 24
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),       # -> 32 x 12 x 12
            nn.Conv2d(32, 64, kernel_size=5),  # -> 64 x 8 x 8
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),       # -> 64 x 4 x 4
        ]
        self.conv = nn.Sequential(*conv_layers)

        # Fully connected head projecting the flattened maps to the embedding.
        flat_features = 64 * 4 * 4
        head_layers = [
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Linear(256, embedding_dim),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return L2-normalized embeddings of shape (batch, embedding_dim)."""
        feature_maps = self.conv(x)
        batch_size = feature_maps.size(0)
        flattened = feature_maps.view(batch_size, -1)
        embedding = self.fc(flattened)
        # Project every row onto the unit sphere.
        return F.normalize(embedding, p=2, dim=1)

In [5]:
# Reproducibility: triplet sampling uses `random`, while shuffling and weight
# initialization use torch's RNG, so seed them all before anything is built.
# Runs can still differ slightly where CUDA kernels are nondeterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# data
train_ds = TripletMNIST(root="../data", train=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=512, shuffle=True, num_workers=4, pin_memory=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model
model = EmbeddingNet(embedding_dim=32)
model = nn.DataParallel(model.to(device))

# others
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.TripletMarginLoss(margin=1.0, p=2)

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for (anc, pos, neg), _ in train_loader:
        # The loader pins host memory, so the copies can be asynchronous.
        anc = anc.to(device, non_blocking=True)
        pos = pos.to(device, non_blocking=True)
        neg = neg.to(device, non_blocking=True)

        # Embed all three roles in a single forward pass: one DataParallel
        # scatter/replicate/gather per step instead of three. The network has
        # no batch-dependent layers, so each sample's embedding is unaffected
        # by what else is in the batch.
        embeddings = model(torch.cat((anc, pos, neg), dim=0))
        emb_a, emb_p, emb_n = embeddings.split(anc.size(0), dim=0)

        loss = criterion(emb_a, emb_p, emb_n)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Mean of the per-batch mean losses.
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:01<00:00, 6.87MB/s]


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 221kB/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 1.68MB/s]


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 2.46MB/s]


Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw

Epoch 1/10 - Loss: 0.1292
Epoch 2/10 - Loss: 0.0235
Epoch 3/10 - Loss: 0.0152
Epoch 4/10 - Loss: 0.0100
Epoch 5/10 - Loss: 0.0072
Epoch 6/10 - Loss: 0.0053
Epoch 7/10 - Loss: 0.0048
Epoch 8/10 - Loss: 0.0044
Epoch 9/10 - Loss: 0.0030
Epoch 10/10 - Loss: 0.0029


In [18]:
# Sanity check: embed two training images and compare their distance.
model.eval()
with torch.no_grad():
    # Read from the wrapped MNIST dataset directly: it returns the transformed
    # image together with its label, and avoids sampling (and transforming)
    # a positive and a negative that would be thrown away.
    x1, y1 = train_ds.mnist[0]
    x2, y2 = train_ds.mnist[10]
    emb1 = model(x1.unsqueeze(0).to(device))
    emb2 = model(x2.unsqueeze(0).to(device))
    # Deliberately not named `dist`: that would overwrite the
    # `torch.distributed as dist` alias from the import cell.
    emb_distance = torch.norm(emb1 - emb2, p=2).item()
    print("Embedding Distance:", emb_distance)
    # Show the labels so the distance can be read as same-class or cross-class.
    print(f"Labels: {y1} vs {y2}")

Embedding Distance: 1.0376266241073608
