In [6]:
import torch
import torch.nn as nn
import timm
from tqdm import tqdm

from pytorch_adapt.containers import Models, Optimizers
from pytorch_adapt.datasets import DataloaderCreator, get_mnist_mnistm, SourceDataset, TargetDataset, CombinedSourceAndTargetDataset
from pytorch_adapt.hooks import DANNHook
from pytorch_adapt.models import Discriminator, mnistC, mnistG, Classifier
from pytorch_adapt.utils.common_functions import batch_to_device
from pytorch_adapt.validators import IMValidator

In [1]:
from Dataset import OfficeHome
from util import utils

In [2]:
domains = ["Art", "Clipart", "Product", "Real World"]

In [3]:
src_train_dataloader = utils.get_train_loader(domains[0])
src_test_dataloader = utils.get_test_loader(domains[0])
tgt_train_dataloader = utils.get_train_loader(domains[1])
tgt_test_dataloader = utils.get_test_loader(domains[1])

In [7]:
device = torch.device("cuda")

# G = mnistG(pretrained=True).to(device)
G = timm.create_model('mobilenetv3_small_075', pretrained=True).to(device)
G.classifier = nn.Identity()
C = Classifier(in_size=1024, num_classes=10).to(device)
D = Discriminator(in_size=1024, h=256).to(device)
models = Models({"G": G, "C": C, "D": D})
optimizers = Optimizers((torch.optim.Adam, {"lr": 0.0001}))
optimizers.create_with(models)
optimizers = list(optimizers.values())

hook = DANNHook(optimizers)
validator = IMValidator()

No pretrained weights exist for this model. Using random initialization.


In [8]:
src_train_dataset = SourceDataset(src_train_dataloader.dataset)
src_test_dataset = SourceDataset(src_test_dataloader.dataset)
tgt_train_dataset = TargetDataset(tgt_train_dataloader.dataset)
tgt_test_dataset = TargetDataset(tgt_test_dataloader.dataset)

In [9]:
custom_datasets = {
    "src_train": src_train_dataset,
    "src_val": src_test_dataset,
    "target_train": tgt_train_dataset,
    "target_val": tgt_test_dataset,
    "train": CombinedSourceAndTargetDataset(source_dataset=src_train_dataset, target_dataset=tgt_train_dataset),
}

In [10]:
dc = DataloaderCreator(batch_size=32, num_workers=2)
dataloaders = dc(**custom_datasets)

In [11]:
for epoch in range(2):

    # train loop
    models.train()
    for data in tqdm(dataloaders["train"]):
        data = batch_to_device(data, device)
        loss, _ = hook({}, {**models, **data})

    # eval loop
    models.eval()
    logits = []
    with torch.no_grad():
        for data in tqdm(dataloaders["target_train"]):
            data = batch_to_device(data, device)
            logits.append(C(G(data["target_imgs"])))
        logits = torch.cat(logits, dim=0)

    # validation score
    score = validator.score(target_train={"logits": logits})
    print(f"Epoch {epoch} score = {score}")

100%|██████████| 12/12 [00:10<00:00,  1.18it/s]
100%|██████████| 13/13 [00:05<00:00,  2.32it/s]
  0%|          | 0/12 [00:00<?, ?it/s]

Epoch 0 score = 4.76837158203125e-07


100%|██████████| 12/12 [00:09<00:00,  1.32it/s]
100%|██████████| 13/13 [00:05<00:00,  2.35it/s]

Epoch 1 score = 0.026787281036376953





In [48]:
dataloaders["train"].dataset

CombinedSourceAndTargetDataset(
  (source_dataset): SourceDataset(
    domain=0
    (dataset): ConcatDataset(
      len=60000
      (datasets): [Dataset MNIST
          Number of datapoints: 60000
          Root location: .
          Split: Train
          StandardTransform
      Transform: Compose(
                     Resize(size=32, interpolation=bilinear)
                     ToTensor()
                     <pytorch_adapt.utils.transforms.GrayscaleToRGB object at 0x0000021387F6FF08>
                     Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                 )]
    )
  )
  (target_dataset): TargetDataset(
    domain=1
    (dataset): ConcatDataset(
      len=59001
      (datasets): [MNISTM(
        domain=MNISTM
        len=59001
        (transform): Compose(
            ToTensor()
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        )
      )]
    )
  )
)

In [41]:
datasets

{'src_train': SourceDataset(
   domain=0
   (dataset): ConcatDataset(
     len=60000
     (datasets): [Dataset MNIST
         Number of datapoints: 60000
         Root location: .
         Split: Train
         StandardTransform
     Transform: Compose(
                    Resize(size=32, interpolation=bilinear)
                    ToTensor()
                    <pytorch_adapt.utils.transforms.GrayscaleToRGB object at 0x00000213882897C8>
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                )]
   )
 ),
 'src_val': SourceDataset(
   domain=0
   (dataset): ConcatDataset(
     len=10000
     (datasets): [Dataset MNIST
         Number of datapoints: 10000
         Root location: .
         Split: Test
         StandardTransform
     Transform: Compose(
                    Resize(size=32, interpolation=bilinear)
                    ToTensor()
                    <pytorch_adapt.utils.transforms.GrayscaleToRGB object at 0x00000213884177C8>
      