In [8]:
from pathlib import Path
from sklearn.manifold import TSNE
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

import torch

from pytorch_adapt.utils.common_functions import batch_to_device

from adapt.config import Arguments
from adapt.trainer import Trainer

In [9]:
# Domain names; args.src_data / args.tgt_data are indices into this list.
domains = ["Art", "Clipart", "Product", "Real World"]

# Experiment configuration consumed by Trainer below.
args = Arguments()
args.src_data, args.tgt_data = 2, 1
args.model_name_or_path = "resmlp_12_distilled_224"
args.embed_dim = 1000
args.dir_name = f"{args.model_name_or_path}_src({domains[args.src_data]})_tgt({domains[args.tgt_data]})"

# Every entry under the run's output directory, in lexicographic order.
# Later cells glob each entry for its *.pt files.
checkpoint_pattern = f"./outputs/{args.dir_name}/*"
model_list = glob(checkpoint_pattern)
model_list.sort()

trainer = Trainer(args)

[12/05/2021 14:37:32] INFO - adapt.trainer: Load PyTorch-Adapt Dataset.
[12/05/2021 14:37:32] INFO - adapt.trainer: Successfully loaded PyTorch-Adapt Dataset of source=Product and target=Clipart
[12/05/2021 14:37:32] INFO - adapt.trainer: Setup model.
[12/05/2021 14:37:33] INFO - timm.models.helpers: Loading pretrained weights from url (https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth)
[12/05/2021 14:37:33] INFO - adapt.trainer: Successfully setup model.


In [10]:
%matplotlib

src_name = domains[args.src_data]
tgt_name = domains[args.tgt_data]

fig, ax = plt.subplots(nrows=4, ncols=5)

for idx, epoch in enumerate(model_list):

    row, col = idx // 5, idx % 5

    title = "Epoch" + " ".join(epoch.split("\\")[-1].split("_"))
    trainer.load_model(glob(f"{epoch}/*.pt"))

    src_embed = []
    for data in tqdm(trainer.dataloaders["src_val"]):
        data = batch_to_device(data, trainer.device)
        embed = trainer.models["G"](data["src_imgs"]).detach().cpu()
        src_embed.append(embed)

        del data
        torch.cuda.empty_cache()

    src_embed = torch.cat(src_embed, dim=0).numpy()

    tgt_embed = []
    for data in tqdm(trainer.dataloaders["target_val"]):
        data = batch_to_device(data, trainer.device)
        embed = trainer.models["G"](data["target_imgs"]).detach().cpu()
        tgt_embed.append(embed)

        del data
        torch.cuda.empty_cache()
        
    tgt_embed = torch.cat(tgt_embed, dim=0).numpy()

    tsne = TSNE()

    src_tsne = tsne.fit_transform(src_embed)
    tgt_tsne = tsne.fit_transform(tgt_embed)

    ax[row][col].scatter(src_tsne[:, 0], src_tsne[:, 1], color="r", label=f"Source: {src_name}", alpha=.7)
    ax[row][col].scatter(tgt_tsne[:, 0], tgt_tsne[:, 1], color="b", label=f"Target: {tgt_name}", alpha=.7)
    ax[row][col].legend()
    ax[row][col].set_title(title)
plt.show()

Using matplotlib backend: Qt5Agg


  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.48it/s]
100%|██████████| 4/4 [00:03<00:00,  1.31it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.52it/s]
100%|██████████| 4/4 [00:02<00:00,  1.57it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.47it/s]
100%|██████████| 4/4 [00:02<00:00,  1.52it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.53it/s]
100%|██████████| 4/4 [00:02<00:00,  1.62it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.54it/s]
100%|██████████| 4/4 [00:02<00:00,  1.62it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.55it/s]
100%|██████████| 4/4 [00:02<00:00,  1.62it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.56it/s]
100%|██████████| 4/4 [00:02<00:00,  1.60it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.48it/s]
100%|██████████| 4/4 [00:02<00:00,  1.56it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.53it/s]
100%|██████████| 4/4 [00:02<00:00,  1.51it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.54it/s]
100%|██████████| 4/4 [00:02<00:00,  1.60it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.39it/s]
100%|██████████| 4/4 [00:02<00:00,  1.52it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.46it/s]
100%|██████████| 4/4 [00:02<00:00,  1.59it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.51it/s]
100%|██████████| 4/4 [00:02<00:00,  1.59it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.50it/s]
100%|██████████| 4/4 [00:02<00:00,  1.57it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.48it/s]
100%|██████████| 4/4 [00:02<00:00,  1.45it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.54it/s]
100%|██████████| 4/4 [00:02<00:00,  1.60it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.47it/s]
100%|██████████| 4/4 [00:02<00:00,  1.53it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.47it/s]
100%|██████████| 4/4 [00:02<00:00,  1.54it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.53it/s]
100%|██████████| 4/4 [00:02<00:00,  1.58it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.54it/s]
100%|██████████| 4/4 [00:02<00:00,  1.56it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

G is successfully loaded
C is successfully loaded
D is successfully loaded


100%|██████████| 4/4 [00:02<00:00,  1.53it/s]
100%|██████████| 4/4 [00:02<00:00,  1.53it/s]


IndexError: index 4 is out of bounds for axis 0 with size 4

In [5]:
# Load the weights stored in the first entry of model_list (all of its *.pt files).
first_checkpoint_files = glob(f"{model_list[0]}/*.pt")
trainer.load_model(first_checkpoint_files)

G is successfully loaded
C is successfully loaded
D is successfully loaded


In [6]:
# Collect the outputs of generator G on the source and target validation loaders.
# Each pair is (key into trainer.dataloaders, key of the image tensor in a batch).
embeds_by_split = {}
for split, image_key in (("src_val", "src_imgs"), ("target_val", "target_imgs")):
    chunks = []
    # Inference only: no_grad avoids building an autograd graph that would be
    # thrown away by .detach() anyway.
    with torch.no_grad():
        for data in tqdm(trainer.dataloaders[split]):
            data = batch_to_device(data, trainer.device)
            embed = trainer.models["G"](data[image_key]).detach().cpu()
            chunks.append(embed)

            # Release the device copy of the batch before loading the next one.
            del data
            torch.cuda.empty_cache()

    embeds_by_split[split] = torch.cat(chunks, dim=0).numpy()

src_embed = embeds_by_split["src_val"]
tgt_embed = embeds_by_split["target_val"]

100%|██████████| 4/4 [00:05<00:00,  1.28s/it]
100%|██████████| 4/4 [00:02<00:00,  1.42it/s]


In [9]:
# Project source and target embeddings with a single t-SNE fit so both point
# clouds live in the same 2-D coordinate system. Fitting each domain separately
# produces unrelated layouts that cannot be compared in one plot.
tsne = TSNE(random_state=0)  # fixed seed so reruns give the same layout

joint_embed = torch.cat(
    [torch.from_numpy(src_embed), torch.from_numpy(tgt_embed)], dim=0
).numpy()
joint_tsne = tsne.fit_transform(joint_embed)

# The first len(src_embed) rows are the source samples; the rest are target.
n_src = len(src_embed)
src_tsne = joint_tsne[:n_src]
tgt_tsne = joint_tsne[n_src:]

In [15]:
# Sanity check: size of the first t-SNE coordinate column of the source samples.
first_component = src_tsne[:, 0]
first_component.shape

(113,)

In [17]:
!pip install matplotlib



In [25]:
%matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

src_name = domains[args.src_data]
tgt_name = domains[args.tgt_data]

plt.scatter(src_tsne[:, 0], src_tsne[:, 1], color="r", label=f"Source: {src_name}", alpha=.7)
plt.scatter(tgt_tsne[:, 0], tgt_tsne[:, 1], color="b", label=f"Target: {tgt_name}", alpha=.7)
plt.legend()
plt.show()

Using matplotlib backend: Qt5Agg
