### Install packages

In [None]:
!pip install pytorch-adapt[ignite] seaborn pandas umap-learn

### Import packages

In [None]:
import logging

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import umap
from ignite.engine import Events

from pytorch_adapt.adapters import DANN
from pytorch_adapt.containers import Models, Optimizers
from pytorch_adapt.datasets import DataloaderCreator, get_mnist_mnistm
from pytorch_adapt.frameworks.ignite import Ignite
from pytorch_adapt.models import Discriminator, mnistC, mnistG
from pytorch_adapt.validators import AccuracyValidator, IMValidator

# Attach a default stderr handler to the root logger, then let pytorch-adapt's
# INFO-level messages through so training progress is visible in the notebook.
logging.basicConfig()
pa_logger = logging.getLogger("pytorch-adapt")
pa_logger.setLevel(logging.INFO)

In [None]:
def get_viz_fn(
    trainer,
    dc,
    datasets,
    splits=("src_val", "target_val"),
    umap_kwargs=None,
):
    """Build an event handler that scatter-plots a UMAP embedding of features.

    Features from every split in ``splits`` are collected with
    ``trainer.get_all_outputs``, concatenated, reduced to 2-D with UMAP, and
    plotted coloured by domain label (0 -> "Source", 1 -> "Target").

    Args:
        trainer: pytorch-adapt ``Ignite`` wrapper; its ``get_all_outputs`` is
            used to compute outputs for each split.
        dc: ``DataloaderCreator`` used to build one dataloader per split.
        datasets: mapping from split name to dataset.
        splits: dataset keys whose features are embedded together. Defaults
            to the source and target validation splits.
        umap_kwargs: optional keyword arguments forwarded to ``umap.UMAP``
            (e.g. ``{"random_state": 0}`` for a reproducible layout).

    Returns:
        A callable ``viz(engine=None)`` suitable for
        ``Engine.add_event_handler``.
    """
    # Copy so later mutation of the caller's dict can't change the handler.
    umap_kwargs = dict(umap_kwargs or {})

    def viz(engine=None):
        features, domain = [], []
        for split in splits:
            dataloader = dc(**{split: datasets[split]})[split]
            output = trainer.get_all_outputs(dataloader, split)
            features.append(output[split]["features"])
            domain.append(output[split]["domain"])

        features = torch.cat(features, dim=0).cpu().numpy()
        domain = torch.cat(domain, dim=0).cpu().numpy()
        emb = umap.UMAP(**umap_kwargs).fit_transform(features)

        df = pd.DataFrame(emb).assign(domain=domain)
        df["domain"] = df["domain"].replace({0: "Source", 1: "Target"})

        # Label each figure with the engine's epoch so the plots emitted at
        # different events can be told apart; engine may be None when the
        # handler is called by hand. (Presumably epoch is 0 at Events.STARTED.)
        epoch = getattr(getattr(engine, "state", None), "epoch", None)
        title = "UMAP of features by domain"
        if epoch is not None:
            title += f" (epoch {epoch})"

        sns.set_theme(style="white", rc={"figure.figsize": (12.8, 9.6)})
        # Draw on a fresh figure instead of pyplot's implicit current axes so
        # successive calls never overlay points on a previous plot.
        fig, ax = plt.subplots()
        sns.scatterplot(data=df, x=0, y=1, hue="domain", s=2, ax=ax)
        ax.set(title=title, xlabel="UMAP 1", ylabel="UMAP 2")
        plt.show()
        # Release the figure; this handler fires repeatedly during training.
        plt.close(fig)

    return viz

### Create datasets and dataloaders

In [None]:
# MNIST is passed as the source domain and MNIST-M as the target domain;
# data is downloaded into DATA_FOLDER if not already present.
DATA_FOLDER = "."
BATCH_SIZE = 32
NUM_WORKERS = 2

datasets = get_mnist_mnistm(
    ["mnist"], ["mnistm"], folder=DATA_FOLDER, download=True
)
# Factory that turns named datasets into dataloaders with shared settings.
dc = DataloaderCreator(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

### Create models, optimizers, hook, and validator

In [None]:
# Seed torch's global RNG so the randomly initialised discriminator (and
# subsequent torch RNG draws in this kernel, e.g. dataloader shuffling during
# training) are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# G and C are loaded with pretrained=True; D is constructed without it.
G = mnistG(pretrained=True)
C = mnistC(pretrained=True)
# NOTE(review): in_size presumably must equal the feature dimension produced
# by mnistG — confirm if G is swapped for a different architecture.
D = Discriminator(in_size=1200, h=256)
models = Models({"G": G, "C": C, "D": D})
# One (optimizer class, kwargs) pair, applied by Optimizers to the models.
optimizers = Optimizers((torch.optim.Adam, {"lr": 0.0001}))

adapter = DANN(models=models, optimizers=optimizers)
# IMValidator scores checkpoints; AccuracyValidator is reported as a stat.
trainer = Ignite(adapter, validator=IMValidator(), stat_getter=AccuracyValidator())

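### Attach visualization function

The UMAP plot of source vs. target validation features is drawn once when training starts and again after every second completed epoch.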

In [None]:
# Build the plotting handler, then register it on the wrapped Ignite engine
# (presumably exposed as trainer.trainer): it fires when the run starts and
# after every second completed epoch.
viz_fn = get_viz_fn(trainer, dc, datasets)
viz_condition = Events.EPOCH_COMPLETED(every=2) | Events.STARTED
trainer.trainer.add_event_handler(viz_condition, viz_fn)

### Train and evaluate

In [None]:
MAX_EPOCHS = 4

# trainer.run returns the best validation score and the epoch it occurred at.
best_score, best_epoch = trainer.run(
    datasets, dataloader_creator=dc, max_epochs=MAX_EPOCHS
)
print("best_score={}, best_epoch={}".format(best_score, best_epoch))