In [None]:
!pip install pytorch-adapt

### Create Datasets

In [None]:
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

# Two small synthetic image datasets standing in for the source (x) and
# target (y) domains. torchvision's FakeData derives each sample's random
# seed from its index shifted by `random_offset` (default 0), so without a
# distinct offset the target would be an exact copy of the source.
NUM_SAMPLES = 320
x = FakeData(size=NUM_SAMPLES, transform=ToTensor())
y = FakeData(size=NUM_SAMPLES, transform=ToTensor(), random_offset=NUM_SAMPLES)

### Dataset Wrappers

These wrappers transform datasets so that they are compatible with Adapters and Hooks.

Notice that `CombinedSourceAndTargetDataset` returns the target sample corresponding to the input index, but returns a randomly chosen source sample. Repeated requests for the same index can therefore yield different source samples.

In [None]:
from pytorch_adapt.datasets import (
    CombinedSourceAndTargetDataset,
    SourceDataset,
    TargetDataset,
)

src = SourceDataset(x)
target = TargetDataset(y)

# Show each single-domain wrapper and the keys of one of its samples.
for label, wrapped in (("SourceDataset", src), ("\nTargetDataset", target)):
    print(label, wrapped)
    print(wrapped[0].keys())

src_target = CombinedSourceAndTargetDataset(src, target)
print("\nCombinedSourceAndTarget", src_target)

# Request the same index twice: the target index follows the requested index,
# while the source sample is picked at random on every access.
for _attempt in range(2):
    retrieved = src_target[0]
    print("src index", retrieved["src_sample_idx"])
    print("target_index", retrieved["target_sample_idx"])

### Using CombinedSourceAndTargetDataset as input to hooks

In [None]:
import torch

from pytorch_adapt.hooks import FeaturesHook

models = {"G": torch.nn.Conv2d(3, 32, 3)}
dataloader = torch.utils.data.DataLoader(src_target, batch_size=32)
hook = FeaturesHook()

# The demo only needs a single batch, so pull it directly rather than
# starting a for-loop and breaking out after the first iteration.
data = next(iter(dataloader))

# The hook takes one flat dict that merges the models with the batch dict,
# and returns an (outputs, losses) pair.
outputs, losses = hook({**models, **data})
print(outputs.keys())

### DataloaderCreator

`DataloaderCreator` is a factory class. It lets you specify how dataloaders should be built for multiple datasets.

In [None]:
from pytorch_adapt.datasets import DataloaderCreator


def print_dataloaders(dataloaders):
    """Print a ``{name: (batch_size, num_workers)}`` summary of dataloaders.

    Args:
        dataloaders: mapping from a dataset name to a dataloader (any object
            exposing ``batch_size`` and ``num_workers`` attributes works).
    """
    summary = {}
    for name, loader in dataloaders.items():
        summary[name] = (loader.batch_size, loader.num_workers)
    print(summary)


# Loader settings shared by the examples below. Every call gets its own copy,
# so no two DataloaderCreator instances share one mutable dict.
TRAIN_LOADER_KWARGS = {"batch_size": 64, "num_workers": 2}
VAL_LOADER_KWARGS = {"batch_size": 256, "num_workers": 4}

# 1) One batch_size / num_workers pair applied to every dataset.
dc = DataloaderCreator(**TRAIN_LOADER_KWARGS)
dataloaders = dc(train=src_target, src_train=src, target_train=target)
print_dataloaders(dataloaders)

# 2) Different settings for the train datasets vs. the val datasets.
dc = DataloaderCreator(
    train_kwargs=dict(TRAIN_LOADER_KWARGS),
    val_kwargs=dict(VAL_LOADER_KWARGS),
)
dataloaders = dc(train=src_target, src_val=src, target_val=target)
print_dataloaders(dataloaders)

# 3) Explicitly name which inputs are the validation datasets.
dc = DataloaderCreator(val_kwargs=dict(VAL_LOADER_KWARGS), val_names=["val1", "val2"])
dataloaders = dc(train=src_target, val1=src, val2=target)
print_dataloaders(dataloaders)

# 4) Treat every input as a validation dataset, regardless of its name.
dc = DataloaderCreator(val_kwargs=dict(VAL_LOADER_KWARGS), all_val=True)
dataloaders = dc(train=src_target, val=src, woof=target)
print_dataloaders(dataloaders)