In [1]:
pip install pytorch-adapt

Note: you may need to restart the kernel to use updated packages.


In [8]:
import torch
import torchvision.transforms as T
from torchvision.datasets import MNIST

from pytorch_adapt.datasets import (
    MNISTM,
    CombinedSourceAndTargetDataset,
    SourceDataset,
    TargetDataset,
)
from pytorch_adapt.hooks import DANNHook
from pytorch_adapt.models import Classifier, Discriminator, MNISTFeatures
from pytorch_adapt.utils.common_functions import batch_to_device
from pytorch_adapt.utils.constants import IMAGENET_MEAN, IMAGENET_STD
from pytorch_adapt.utils.transforms import GrayscaleToRGB

In [3]:
# Directory where MNIST / MNIST-M live (MNIST is downloaded here).
# Change this one constant to relocate the data.
DATA_ROOT = "."

# Both domains are normalized with the same ImageNet statistics.
normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# Source domain (MNIST): resize to 32 and convert grayscale to RGB. This is
# presumably so the inputs match the 3-channel MNIST-M images. TODO confirm
# against the input expected by MNISTFeatures.
mnist_T = T.Compose(
    [
        T.Resize(32),
        T.ToTensor(),
        GrayscaleToRGB(),
        normalize,
    ]
)

# Target domain (MNIST-M): no resize, only tensor conversion plus normalization.
mnistm_T = T.Compose(
    [
        T.ToTensor(),
        normalize,
    ]
)

# Only the training split passes download=True. The test split relies on the
# files being present under DATA_ROOT after that call.
src_train = MNIST(root=DATA_ROOT, train=True, download=True, transform=mnist_T)
src_val = MNIST(root=DATA_ROOT, train=False, transform=mnist_T)
# NOTE(review): MNISTM is not asked to download here. Presumably the data must
# already exist under DATA_ROOT. Confirm.
target_train = MNISTM(root=DATA_ROOT, train=True, transform=mnistm_T)
target_val = MNISTM(root=DATA_ROOT, train=False, transform=mnistm_T)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [4]:
# Wrap the raw torchvision / MNISTM datasets for pytorch-adapt.
# The wrapped datasets get new names instead of rebinding src_train etc. This
# keeps the cell idempotent: re-running it no longer wraps an already-wrapped
# dataset a second time.
src_train_ds = SourceDataset(src_train)
src_val_ds = SourceDataset(src_val)
target_train_ds = TargetDataset(target_train)
target_val_ds = TargetDataset(target_val)

# Training set that pairs source and target samples.
train_set = CombinedSourceAndTargetDataset(src_train_ds, target_train_ds)
batch_size = 64
num_workers = 2

# shuffle=True: a training loader should not present samples in the same fixed
# order on every epoch.
dataloader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
)

In [5]:
# Use the GPU when one is available, otherwise fall back to CPU. A hard-coded
# "cuda" raises on machines without a CUDA device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed before building the models so their initial weights are reproducible.
SEED = 0
torch.manual_seed(SEED)

# G: feature extractor, C: 10-way classifier, D: domain discriminator.
# NOTE(review): in_size=1200 presumably has to equal the feature size produced
# by MNISTFeatures. Confirm.
G = MNISTFeatures().to(device)
C = Classifier(num_classes=10, in_size=1200, h=256).to(device)
D = Discriminator(in_size=1200, h=256).to(device)
models = {"G": G, "C": C, "D": D}

# One Adam optimizer per model, all with the same learning rate.
LR = 0.0001
G_opt = torch.optim.Adam(G.parameters(), lr=LR)
C_opt = torch.optim.Adam(C.parameters(), lr=LR)
D_opt = torch.optim.Adam(D.parameters(), lr=LR)
opts = [G_opt, C_opt, D_opt]

# The DANN hook is given the optimizers, presumably so it can step them itself. Confirm.
hook = DANNHook(opts)

In [9]:
# Smoke test: run a single DANN step on the first batch only.
# An empty dataloader raises StopIteration here instead of silently doing nothing.
data = batch_to_device(next(iter(dataloader)), device)

# The first return value holds the loss dict (src/target domain loss, c_loss, total).
# NOTE(review): the meaning of the second return value is not visible here. It may
# be the model outputs rather than loss components; confirm and rename if so.
loss, loss_components = hook({}, {**models, **data})
loss

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


{'total_loss': {'src_domain_loss': 0.6909892559051514, 'target_domain_loss': 0.700066089630127, 'c_loss': 2.305485248565674, 'total': 1.2321802377700806}}
