# PyTorch Adaptを用いたDANN（Domain Adversarial Neural Networks）の実装



In [1]:
!pip install -q pytorch-adapt

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m158.2/158.2 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m419.6/419.6 kB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m111.4/111.4 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
import torch
from tqdm import tqdm

from pytorch_adapt.containers import Models, Optimizers
from pytorch_adapt.datasets import DataloaderCreator, get_mnist_mnistm
from pytorch_adapt.hooks import DANNHook
from pytorch_adapt.models import Discriminator, mnistC, mnistG
from pytorch_adapt.utils.common_functions import batch_to_device
from pytorch_adapt.validators import IMValidator


In [4]:
# Download the datasets into the current directory: the first list names the
# source domain(s) ("mnist"), the second the target domain(s) ("mnistm").
datasets = get_mnist_mnistm(
    ["mnist"],
    ["mnistm"],
    folder=".",
    download=True,
)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 96075720.57it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 38504670.64it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 25738529.49it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz



100%|██████████| 4542/4542 [00:00<00:00, 2044486.88it/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading https://public.boxcloud.com/d/1/b1!NegopJ_HO2jtwhRaUEEE_CdCjI8rnCdODdv-YTxIPLjAQmvQneclqnL4j-V1XKFK-PgD5Q3MSuTpqz7VZNjeGhZm_A_yJLunicLnL76P5q8qLCdVFftbNNr734TAE05OUyq1xXxqjc4WoOjiyQWvC8jJePSUd9fslUNNqgZ0cUrwupxI8omtNkiipXoWASP6wjFHv5yOCwbPxUpMU43rw_FfixmFI5w4pfMWUBBXfe6FrLQo0Xk6ywBIVcKFWg3x4lL7ux6d3WO4KFK1MVXeeyfSkk_Z12oeTvtTOHAIjc4sf2ISCx0mrGXCKRrc5Yvzv-ASsN2QeKMSeLac8AtNtk9f7hDgvvgAP12byQ95rJEAgv7pp0sTzZTTAbaoHydTtQNgwQDIbY0DqCP5A9nEz6B94bXmaBIP-rvAMRsYYDp51tFT8vqkL0vbLxGnDZy_UNCMgK1BlswCrE5buRyv_-OCe30k8CiKxKP6LOcv6jxkZe3ZlbrI0mqnKoANzXP3pcqW9A02g3t4J0Zi5d9lZYGmw8fuflDQeDpURP9o0hFwKL0-ok5kES_SmBmx8t1_PjjxwMpjHzAgHJ24aiojWo2WdSEF551KKZ7U_vqDBeGODv9fxSq5WjB2S-nVUTbKV9EOpWoZzLy8zRmTMT1kmvXZ7vHAL7942Lg41yTTANFEu5ND9kEvqBUpL3ro0zlLBuBY-GvyAKNrbqR_qKkHcp7ECvng-FyZDttcGRhQkylxvnwK68U3ecgQe6v8O2FXtoJ9RJBlo0Sl4epTYcwC7gTVQii4q1kDBf1I2PwQ9VQ7nba_IAzAaG3Y5zLvj85XPSieoz5PAATFJAiyXpPMkXWZILGW-mdT_Pts8d0yn7J-MMy6XkGZSelO

100%|██████████| 134178716/134178716 [00:15<00:00, 8944557.50it/s] 
100%|██████████| 68007/68007 [00:10<00:00, 6738.88it/s]


In [5]:
# Build the dataloaders: the creator is configured once (batch size 32) and
# then called with the dataset dict unpacked as keyword arguments.
dc = DataloaderCreator(
    batch_size=32,
)
dataloaders = dc(**datasets)

In [7]:
# Inspect the DataloaderCreator instance built in the previous cell.
dc

<pytorch_adapt.datasets.dataloader_creator.DataloaderCreator at 0x7f64b4bc34f0>

In [6]:
# Inspect the mapping of split name -> DataLoader returned by the creator.
dataloaders

{'src_train': <torch.utils.data.dataloader.DataLoader at 0x7f64b4bc3cd0>,
 'src_val': <torch.utils.data.dataloader.DataLoader at 0x7f64b4bc2050>,
 'target_train': <torch.utils.data.dataloader.DataLoader at 0x7f64b4bc3700>,
 'target_val': <torch.utils.data.dataloader.DataLoader at 0x7f64b4bc3ee0>,
 'train': <torch.utils.data.dataloader.DataLoader at 0x7f64b4bc3940>}

# モデルの定義

In [8]:
# Select the compute device: CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
# Show which device was selected.
device

device(type='cuda')

In [16]:
# Shared feature extractor.
G = mnistG(pretrained=False).to(device)

# Fully connected head for class prediction.
C = mnistC(pretrained=False).to(device)

# Fully connected domain discriminator used for the adversarial objective.
D = Discriminator(in_size=1200,  # input feature dimension (matches C's in_features)
                  h=256          # hidden layer width
                  ).to(device)

# Bundle the three networks so they can be handled together.
models = Models({"G": G, "C": C, "D": D})

# Optimizer spec: (optimizer class, keyword arguments).
optimizers = Optimizers((torch.optim.Adam, {"lr": 0.001}))

# Instantiate the optimizers for the models from the spec above, then keep
# them as a plain list, which is the form passed to DANNHook below.
optimizers.create_with(models)
optimizers = list(optimizers.values())

# Hook that runs one DANN training step when called with models + batch.
hook = DANNHook(optimizers)

# Validator used to score the target-domain logits after each epoch.
validator = IMValidator()

In [12]:
# Inspect the feature extractor architecture.
G

MNISTFeatures(
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 48, kernel_size=(5, 5), stride=(1, 1))
  (fc): Identity()
)

In [15]:
# Inspect the classifier architecture.
C

Classifier(
  (net): Sequential(
    (0): Linear(in_features=1200, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=128, out_features=10, bias=True)
  )
)

In [17]:
# Inspect the discriminator architecture.
D

Discriminator(
  (net): Sequential(
    (0): Linear(in_features=1200, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [18]:
# Inspect the Models container holding G, C and D.
models

G: MNISTFeatures(
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 48, kernel_size=(5, 5), stride=(1, 1))
  (fc): Identity()
)
C: Classifier(
  (net): Sequential(
    (0): Linear(in_features=1200, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=128, out_features=10, bias=True)
  )
)
D: Discriminator(
  (net): Sequential(
    (0): Linear(in_features=1200, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [19]:
# Inspect the list of optimizers created for the models.
optimizers

[Adam (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     differentiable: False
     eps: 1e-08
     foreach: None
     fused: None
     lr: 0.001
     maximize: False
     weight_decay: 0
 ),
 Adam (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     differentiable: False
     eps: 1e-08
     foreach: None
     fused: None
     lr: 0.001
     maximize: False
     weight_decay: 0
 ),
 Adam (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     capturable: False
     differentiable: False
     eps: 1e-08
     foreach: None
     fused: None
     lr: 0.001
     maximize: False
     weight_decay: 0
 )]

In [20]:
# Inspect the DANNHook and the optimizers it wraps.
hook

DANNHook(
  (hook): ChainHook(
    (hooks): (OptimizerHook(
      optimizers=[Adam (
      Parameter Group 0
          amsgrad: False
          betas: (0.9, 0.999)
          capturable: False
          differentiable: False
          eps: 1e-08
          foreach: None
          fused: None
          lr: 0.001
          maximize: False
          weight_decay: 0
      ), Adam (
      Parameter Group 0
          amsgrad: False
          betas: (0.9, 0.999)
          capturable: False
          differentiable: False
          eps: 1e-08
          foreach: None
          fused: None
          lr: 0.001
          maximize: False
          weight_decay: 0
      ), Adam (
      Parameter Group 0
          amsgrad: False
          betas: (0.9, 0.999)
          capturable: False
          differentiable: False
          eps: 1e-08
          foreach: None
          fused: None
          lr: 0.001
          maximize: False
          weight_decay: 0
      )]
      weighter=MeanWeighter(
        wei

In [21]:
# Inspect the validator configuration.
validator

IMValidator(
  required_data=['target_train']
  weights={'entropy': 1, 'diversity': 1}
  (entropy): EntropyValidator(required_data=['target_train'])
  (diversity): DiversityValidator(required_data=['target_train'])
)

# 学習の実行

In [23]:
num_epoch = 3

for epoch in range(num_epoch):
    # ---- training phase ----
    models.train()

    # The "train" dataloader yields batches that combine source and target data.
    for data in tqdm(dataloaders["train"]):
        data = batch_to_device(data, device)
        hook_inputs = {**models, **data}
        _, loss = hook(hook_inputs)

    # ---- evaluation phase ----
    models.eval()

    # Per-batch classifier logits on the target training images.
    logit_chunks = []

    with torch.no_grad():
        for data in tqdm(dataloaders["target_train"]):
            data = batch_to_device(data, device)
            features = G(data["target_imgs"])
            logit_chunks.append(C(features))
        # Stack all batches along the sample dimension.
        logits = torch.cat(logit_chunks, dim=0)

    # Score the target logits with the validator (entropy + diversity terms).
    score = validator(target_train={"logits": logits})

    print(f"\nEpoch {epoch} score = {score}\n")


100%|██████████| 1843/1843 [01:30<00:00, 20.44it/s]
100%|██████████| 1844/1844 [00:32<00:00, 57.41it/s]



Epoch 0 score = 1.300287902355194



100%|██████████| 1843/1843 [01:13<00:00, 24.91it/s]
100%|██████████| 1844/1844 [00:31<00:00, 57.97it/s]



Epoch 1 score = 1.4996307492256165



100%|██████████| 1843/1843 [01:13<00:00, 25.15it/s]
100%|██████████| 1844/1844 [00:31<00:00, 58.69it/s]


Epoch 2 score = 1.6046777367591858




