### Install package

In [1]:
!pip install pytorch-adapt

Collecting pytorch-adapt
  Downloading pytorch_adapt-0.0.61-py3-none-any.whl (137 kB)
[K     |████████████████████████████████| 137 kB 12.6 MB/s 
Collecting torchmetrics
  Downloading torchmetrics-0.7.2-py3-none-any.whl (397 kB)
[K     |████████████████████████████████| 397 kB 44.7 MB/s 
Collecting pytorch-metric-learning>=1.1.0
  Downloading pytorch_metric_learning-1.2.0-py3-none-any.whl (107 kB)
[K     |████████████████████████████████| 107 kB 51.1 MB/s 
Collecting pyDeprecate==0.3.*
  Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Installing collected packages: pyDeprecate, torchmetrics, pytorch-metric-learning, pytorch-adapt
Successfully installed pyDeprecate-0.3.2 pytorch-adapt-0.0.61 pytorch-metric-learning-1.2.0 torchmetrics-0.7.2


### Import packages

In [2]:
import torch
from tqdm import tqdm

from pytorch_adapt.containers import Models, Optimizers
from pytorch_adapt.datasets import DataloaderCreator, get_mnist_mnistm
from pytorch_adapt.hooks import DANNHook
from pytorch_adapt.models import Discriminator, mnistC, mnistG
from pytorch_adapt.utils.common_functions import batch_to_device
from pytorch_adapt.validators import IMValidator

### Create datasets and dataloaders

In [3]:
# Fetch the domain-adaptation datasets into the working directory (downloaded on the
# first run). The first list presumably names the source domain(s) (MNIST) and the
# second the target domain(s) (MNIST-M) -- the eval loop below reads "target_imgs".
datasets = get_mnist_mnistm(
    ["mnist"],
    ["mnistm"],
    folder=".",
    download=True,
)

# Wrap the returned splits in DataLoaders, keyed by split name
# (e.g. "train" and "target_train" are used in the training cell).
dc = DataloaderCreator(num_workers=2, batch_size=32)
dataloaders = dc(**datasets)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading https://public.boxcloud.com/d/1/b1!WEHliV_vLyrC0Iiquf6P1gtU32kmk80D6Ov9RE7eWZdUfzhIRCo8_ylMfkyii60mT-Z7Ko9vlyWManWnHcCWGm6ppZL1tzzIYtuTUtfHRANxlLWvgpdEjOQpnWDuuuichWX7Tl-PY52APDt1s7rM_aYC0GqgECCwr7kVAGeg06jv30kgQAn7bhrMYMC1BlFIX0f1UFuRWV0ZnpsWsmQug2kXCvVpnmVzs_C6WBQDsghegWa4qZ0V-2xwTDO4NA2Ociy4yQmw5PlKy64Synmsp_bWPv55hKMvxDhgEBPveK8_fPWKKMBgHoKVLTsCVbg1SbaQQI930_pQ8k7aI2nf1tt5jvSPCHQ2fLuxUsPrvzcn5-h8NuiSv4lG8Zbovz-C5OSxMjp45jjMnK2OKQaAdK0k2JzdSi2-i2zA9qV3CEpsDLJu9JGHImuNzk0VRoKNfPcMCNGQpkJlcGYGxHHpNHi-g4ny6PDZiXhx7tvbLk1L9urU-iQNcGqPRR16WTrRGkUKVoU-1VjQjTCBKCXIhVBm_HqbYXitwjCdSUnOOCJRNM5fO2kVBP72EK5xhhbOdP0BqBkeXb-OELliGKFF5NPHJilggAqO4pG99fCRIy84EP1BV_mZF7hppi3eY58KOcyWg9TmUXp8V8QCr52qBkXpBTAY1927hBhx92wvVHus-CBfHYYnh1rBUxrEqjpK2WWpmM9laEanseb3ZOUdLQBsN3Afff4N0p4va-6zwA1zokC9_afmGz3h7AYzzZEcmp_0E31B2esFcruiaZWvS3PBds1CqZmhph3f6gGqmjO8KbAZEZ2htNn3wW6-4Tn2zaff5Q_JCMQTOISarLoOaZwSX7fEmGxybEn9xJrG7wHB4AUucq8mtUas

  0%|          | 0/134178716 [00:00<?, ?it/s]

100%|██████████| 68007/68007 [00:17<00:00, 3817.40it/s]


### Create models, optimizers, hook, and validator

In [4]:
# Fall back to CPU when no GPU is present instead of failing on the first .to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG so the randomly initialised discriminator (and any torch-RNG-driven
# shuffling during training) is repeatable across runs. GPU kernels may still be
# nondeterministic, so scores can differ slightly between runs.
torch.manual_seed(0)

# G produces the features that C classifies (the eval loop computes C(G(imgs))).
# Both start from pretrained checkpoints; D is trained from scratch.
# NOTE(review): in_size=1200 presumably matches G's output width -- TODO confirm.
G = mnistG(pretrained=True).to(device)
C = mnistC(pretrained=True).to(device)
D = Discriminator(in_size=1200, h=256).to(device)
models = Models({"G": G, "C": C, "D": D})

# Build Adam optimizers (lr=1e-4) for the models, then hand DANNHook a plain list.
# A separate name for the container avoids rebinding `optimizers` to a different type.
optimizer_container = Optimizers((torch.optim.Adam, {"lr": 0.0001}))
optimizer_container.create_with(models)
optimizers = list(optimizer_container.values())

hook = DANNHook(optimizers)
# Scores the model from target-domain logits alone (no labels are passed to it below).
validator = IMValidator()

Downloading: "https://cornell.box.com/shared/static/tdx0ts24e273j7mf3r2ox7a12xh4fdfy" to /root/.cache/torch/hub/checkpoints/mnistG-68ee7945.pth


  0%|          | 0.00/161k [00:00<?, ?B/s]

Downloading: "https://cornell.box.com/shared/static/j4zrogronmievq1csulrkai7zjm27gcq" to /root/.cache/torch/hub/checkpoints/mnistC-ac7b5a13.pth


  0%|          | 0.00/1.31M [00:00<?, ?B/s]

### Train and evaluate

In [5]:
NUM_EPOCHS = 2  # number of full passes over the training loader

for epoch in range(NUM_EPOCHS):

    # train loop: one pass over dataloaders["train"]; the hook is called with the
    # models and the batch, and returns (outputs, losses) -- only losses are kept.
    models.train()
    for data in tqdm(dataloaders["train"]):
        data = batch_to_device(data, device)
        _, loss = hook({**models, **data})

    # eval loop: collect classifier logits for every target-domain image
    models.eval()
    logit_batches = []
    with torch.no_grad():
        for data in tqdm(dataloaders["target_train"]):
            data = batch_to_device(data, device)
            logit_batches.append(C(G(data["target_imgs"])))
    # Separate names keep the per-batch list distinct from the stacked tensor.
    logits = torch.cat(logit_batches, dim=0)

    # validation score: only target logits are supplied, no labels
    score = validator(target_train={"logits": logits})
    print(f"\nEpoch {epoch} score = {score}\n")

100%|██████████| 1843/1843 [01:13<00:00, 24.99it/s]
100%|██████████| 1844/1844 [00:27<00:00, 65.96it/s]



Epoch 0 score = 1.0747219324111938



100%|██████████| 1843/1843 [01:12<00:00, 25.31it/s]
100%|██████████| 1844/1844 [00:27<00:00, 65.88it/s]


Epoch 1 score = 1.31474369764328




