In [None]:
!pip install pytorch-adapt

### Datasets: Source and Target Datasets

In [None]:
from torchvision.datasets import MNIST

from pytorch_adapt.datasets import (
    MNISTM,
    CombinedSourceAndTargetDataset,
    SourceDataset,
    TargetDataset,
)

# download=True fetches the data into root when it is not already there, so
# this cell also works on a fresh machine (right after the pip install).
mnist_raw = MNIST(root=".", train=True, transform=None, download=True)
mnistm_raw = MNISTM(root=".", train=True, transform=None, download=True)
# the raw datasets return (data, label) tuples
print(mnist_raw[0])
print(mnistm_raw[0])

# wrap them as the source and target domains
x = SourceDataset(mnist_raw)
y = TargetDataset(mnistm_raw)
# x and y return dictionaries
print(x[0])
print(y[0])

xy = CombinedSourceAndTargetDataset(x, y)
# xy returns a dictionary
print(xy[0])

### Datasets: Getters and DataloaderCreator

In [None]:
from pytorch_adapt.datasets import DataloaderCreator, get_mnist_mnistm

# download=True lets this cell run on a fresh machine where the MNIST /
# MNIST-M files are not yet present under `folder`.
datasets = get_mnist_mnistm(["mnist"], ["mnistm"], folder=".", download=True)
dc = DataloaderCreator(batch_size=128)
dataloaders = dc(**datasets)

# datasets and dataloaders are dictionaries
print(datasets)
print(dataloaders)

### Hooks: Computing Features

In [None]:
# torch is imported here so this cell does not depend on a later cell.
import torch

from pytorch_adapt.hooks import FeaturesHook

# G maps 1000-dim inputs to 100-dim outputs; random tensors stand in for images.
G = torch.nn.Linear(1000, 100)
models = {"G": G}
data = {
    "src_imgs": torch.randn(32, 1000),
    "target_imgs": torch.randn(32, 1000),
}

hook = FeaturesHook()

losses, outputs = hook({}, {**models, **data})
# outputs contains src_imgs_features and target_imgs_features
print(outputs.keys())

# passing the features back in: the hook has nothing new to compute
losses, outputs = hook({}, {**models, **data, **outputs})
# outputs is empty
print(outputs.keys())

hook = FeaturesHook(detach=True)
# note: `outputs` here is the (empty) dict returned by the previous call
losses, outputs = hook({}, {**models, **data, **outputs})
# outputs contains
# src_imgs_features_detached and target_imgs_features_detached
print(outputs.keys())

### Weighters

In [None]:
import torch

from pytorch_adapt.weighters import MeanWeighter

# Per-loss weights: "y" is multiplied by 2.3. Any loss without an entry
# (here "x") is multiplied by the default weight of 1.
weighter = MeanWeighter(weights={"y": 2.3})

batch_size = 32
logits = torch.randn(batch_size, 512, requires_grad=True)
labels = torch.randint(0, 10, size=(batch_size,))

F = torch.nn.functional
x = F.cross_entropy(logits, labels)
y = torch.norm(logits)

# The weighter returns the combined `loss` and a `components` object
# (presumably the individual weighted terms; see the MeanWeighter docs).
loss, components = weighter({"x": x, "y": y})
loss.backward()