### Install package

In [1]:
!pip install pytorch-adapt[lightning]

Collecting pytorch-adapt[lightning]
  Downloading pytorch_adapt-0.0.61-py3-none-any.whl (137 kB)
     |████████████████████████████████| 137 kB 4.4 MB/s
Collecting pytorch-metric-learning>=1.1.0
  Downloading pytorch_metric_learning-1.2.0-py3-none-any.whl (107 kB)
     |████████████████████████████████| 107 kB 37.2 MB/s
Collecting torchmetrics
  Downloading torchmetrics-0.7.2-py3-none-any.whl (397 kB)
     |████████████████████████████████| 397 kB 35.9 MB/s
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.5.10-py3-none-any.whl (527 kB)
     |████████████████████████████████| 527 kB 41.0 MB/s
Collecting setuptools==59.5.0
  Downloading setuptools-59.5.0-py3-none-any.whl (952 kB)
     |████████████████████████████████| 952 kB 35.9 MB/s
Collecting pyDeprecate==0.3.1
  Downloading pyDeprecate-0.3.1-py3-none-any.whl (10 kB)
Collecting future>=0.17.1
  Downloading future-0.18.2.tar.gz (829 kB)
     |████████████████████████████████| 829 kB 38. … [output truncated]

### Import packages

In [2]:
import pytorch_lightning as pl
import torch

from pytorch_adapt.adapters import DANN
from pytorch_adapt.containers import Models, Optimizers
from pytorch_adapt.datasets import DataloaderCreator, get_mnist_mnistm
from pytorch_adapt.frameworks.lightning import Lightning
from pytorch_adapt.frameworks.utils import filter_datasets
from pytorch_adapt.models import Discriminator, mnistC, mnistG
from pytorch_adapt.validators import IMValidator

### Create datasets and a dataloader creator

In [3]:
# MNIST as the source domain and MNIST-M as the target domain (assuming the
# argument order is source domains, then target domains — per the
# pytorch-adapt docs). Files are stored under the current directory.
datasets = get_mnist_mnistm(
    ["mnist"],
    ["mnistm"],
    folder=".",
    download=True,
)
# Callable that turns a dict of datasets into a dict of DataLoaders
# (invoked in the model-setup cell).
dc = DataloaderCreator(num_workers=2, batch_size=32)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading https://public.boxcloud.com/d/1/… [signed download URL redacted]

  0%|          | 0/134178716 [00:00<?, ?it/s]

100%|██████████| 68007/68007 [00:16<00:00, 4169.63it/s]


### Create models, optimizers, adapter, validator, and trainer

In [4]:
# Seed Python/NumPy/torch RNGs (and dataloader workers) so the discriminator's
# random initialization and the training-data shuffling are reproducible.
SEED = 0
pl.seed_everything(SEED, workers=True)

# G and C are loaded with pretrained weights; D has no pretrained weights and
# starts from a random initialization.
G = mnistG(pretrained=True)
C = mnistC(pretrained=True)
# NOTE(review): in_size presumably must match the feature size produced by G
# (1200 for mnistG) — confirm if G is replaced.
D = Discriminator(in_size=1200, h=256)
models = Models({"G": G, "C": C, "D": D})
# One (optimizer class, kwargs) pair; presumably applied to every model above.
optimizers = Optimizers((torch.optim.Adam, {"lr": 0.0001}))

adapter = DANN(models=models, optimizers=optimizers)
validator = IMValidator()
# Presumably keeps only the dataset splits the validator needs, then wraps
# them in DataLoaders.
dataloaders = dc(**filter_datasets(datasets, validator))
# "train" is used for fitting; every remaining loader is passed to validation.
train_loader = dataloaders.pop("train")

L_adapter = Lightning(adapter, validator=validator)
# Fall back to CPU instead of failing when no CUDA device is available.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0, max_epochs=2)

Downloading: "https://cornell.box.com/shared/static/tdx0ts24e273j7mf3r2ox7a12xh4fdfy" to /root/.cache/torch/hub/checkpoints/mnistG-68ee7945.pth


  0%|          | 0.00/161k [00:00<?, ?B/s]

Downloading: "https://cornell.box.com/shared/static/j4zrogronmievq1csulrkai7zjm27gcq" to /root/.cache/torch/hub/checkpoints/mnistC-ac7b5a13.pth


  0%|          | 0.00/1.31M [00:00<?, ?B/s]

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


### Train and evaluate

In [5]:
# Every loader left in `dataloaders` after popping "train" is used for validation.
validation_loaders = list(dataloaders.values())
trainer.fit(
    L_adapter,
    train_dataloaders=train_loader,
    val_dataloaders=validation_loaders,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Missing logger folder: /content/lightning_logs

  | Name   | Type       | Params
--------------------------------------
0 | models | ModuleDict | 756 K 
1 | misc   | ModuleDict | 0     
--------------------------------------
756 K     Trainable params
0         Non-trainable params
756 K     Total params
3.024     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]