In [1]:
!pip install pytorch-adapt

Collecting pytorch-adapt
  Downloading pytorch_adapt-0.0.61-py3-none-any.whl (137 kB)
[K     |████████████████████████████████| 137 kB 5.5 MB/s 
Collecting torchmetrics
  Downloading torchmetrics-0.7.2-py3-none-any.whl (397 kB)
[K     |████████████████████████████████| 397 kB 12.2 MB/s 
Collecting pytorch-metric-learning>=1.1.0
  Downloading pytorch_metric_learning-1.2.0-py3-none-any.whl (107 kB)
[K     |████████████████████████████████| 107 kB 6.7 MB/s 
Collecting pyDeprecate==0.3.*
  Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Installing collected packages: pyDeprecate, torchmetrics, pytorch-metric-learning, pytorch-adapt
Successfully installed pyDeprecate-0.3.2 pytorch-adapt-0.0.61 pytorch-metric-learning-1.2.0 torchmetrics-0.7.2


### Create Datasets

In [2]:
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

# Two synthetic image datasets of 320 samples each, with images converted to
# tensors. Later cells wrap `x` as the source dataset and `y` as the target.
x, y = (FakeData(size=320, transform=ToTensor()) for _ in range(2))

### Dataset Wrappers

These wrappers transform datasets so that they are compatible with Adapters and Hooks.

Notice that ```CombinedSourceAndTargetDataset``` returns the target sample corresponding to the input index, but pairs it with a randomly chosen source sample. As a result, retrieving the same index twice can return two different source samples.

In [3]:
from pytorch_adapt.datasets import (
    CombinedSourceAndTargetDataset,
    SourceDataset,
    TargetDataset,
)

# Wrap the raw datasets so that each item becomes a dict with
# domain-prefixed keys ("src_*" / "target_*").
src = SourceDataset(x)
target = TargetDataset(y)

# Show each wrapper's repr followed by the keys of its first item.
for heading, wrapped in (("SourceDataset", src), ("\nTargetDataset", target)):
    print(heading, wrapped)
    print(wrapped[0].keys())

# Combine both wrappers into a single dataset.
src_target = CombinedSourceAndTargetDataset(src, target)
print("\nCombinedSourceAndTarget", src_target)

# Fetch index 0 twice and print the source/target sample indices of each
# retrieval, to compare how the two halves respond to the same input index.
for _ in range(2):
    retrieved = src_target[0]
    for caption, key in (
        ("src index", "src_sample_idx"),
        ("target_index", "target_sample_idx"),
    ):
        print(caption, retrieved[key])

SourceDataset SourceDataset(
  domain=0
  (dataset): Dataset FakeData
      Number of datapoints: 320
      StandardTransform
  Transform: ToTensor()
)
dict_keys(['src_imgs', 'src_domain', 'src_labels', 'src_sample_idx'])

TargetDataset TargetDataset(
  domain=1
  (dataset): Dataset FakeData
      Number of datapoints: 320
      StandardTransform
  Transform: ToTensor()
)
dict_keys(['target_imgs', 'target_domain', 'target_sample_idx'])

CombinedSourceAndTarget CombinedSourceAndTargetDataset(
  (source_dataset): SourceDataset(
    domain=0
    (dataset): Dataset FakeData
        Number of datapoints: 320
        StandardTransform
    Transform: ToTensor()
  )
  (target_dataset): TargetDataset(
    domain=1
    (dataset): Dataset FakeData
        Number of datapoints: 320
        StandardTransform
    Transform: ToTensor()
  )
)
src index 20
target_index 0
src index 252
target_index 0


### Using CombinedSourceAndTargetDataset as input to hooks

In [4]:
import torch

from pytorch_adapt.hooks import FeaturesHook

# A single conv layer is registered under the key "G".
# NOTE(review): presumably FeaturesHook looks up "G" as its feature
# extractor -- confirm against the pytorch-adapt docs.
models = dict(G=torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3))
hook = FeaturesHook()
dataloader = torch.utils.data.DataLoader(dataset=src_target, batch_size=32)

# One batch is enough to inspect what the hook produces, so stop after the
# first iteration. The hook receives the models and the batch merged into
# one dict.
for data in dataloader:
    outputs, losses = hook({**models, **data})
    print(outputs.keys())
    break

dict_keys(['src_imgs_features', 'target_imgs_features'])


### DataloaderCreator

```DataloaderCreator``` is a factory class. It allows you to specify how dataloaders should be made for multiple datasets.

In [5]:
from pytorch_adapt.datasets import DataloaderCreator


def print_dataloaders(dataloaders):
    """Print a ``{name: (batch_size, num_workers)}`` summary of dataloaders.

    Args:
        dataloaders: Mapping from a name to an object exposing
            ``batch_size`` and ``num_workers`` attributes.
    """
    summary = {}
    for name, loader in dataloaders.items():
        summary[name] = (loader.batch_size, loader.num_workers)
    print(summary)


# Each entry pairs the DataloaderCreator constructor arguments with the
# named datasets handed to the resulting factory.
examples = (
    # set the batch_size and num_workers for all datasets
    (
        dict(batch_size=64, num_workers=2),
        dict(train=src_target, src_train=src, target_train=target),
    ),
    # set different params for train vs val datasets
    (
        dict(
            train_kwargs=dict(batch_size=64, num_workers=2),
            val_kwargs=dict(batch_size=256, num_workers=4),
        ),
        dict(train=src_target, src_val=src, target_val=target),
    ),
    # specify the name of the validation datasets
    (
        dict(
            val_kwargs=dict(batch_size=256, num_workers=4),
            val_names=["val1", "val2"],
        ),
        dict(train=src_target, val1=src, val2=target),
    ),
    # consider all inputs to be validation datasets
    (
        dict(val_kwargs=dict(batch_size=256, num_workers=4), all_val=True),
        dict(train=src_target, val=src, woof=target),
    ),
)

# Build the dataloaders for every configuration and print one summary line
# (batch_size, num_workers per dataloader) for each.
for creator_args, named_datasets in examples:
    dc = DataloaderCreator(**creator_args)
    dataloaders = dc(**named_datasets)
    print_dataloaders(dataloaders)

{'train': (64, 2), 'src_train': (64, 2), 'target_train': (64, 2)}
{'train': (64, 2), 'src_val': (256, 4), 'target_val': (256, 4)}
{'train': (32, 0), 'val1': (256, 4), 'val2': (256, 4)}
{'train': (256, 4), 'val': (256, 4), 'woof': (256, 4)}


  cpuset_checked))
